NikitaSrivatsan committed on
Commit 48ac659
1 Parent(s): 3d5b800

First pass at captioning functionality through web app

Files changed (10)
  1. .gitignore +1 -0
  2. app.py +2 -1
  3. audiocaptioner.py +68 -0
  4. audiostock-train-240k.txt +0 -0
  5. clipcap.py +405 -0
  6. data_module.py +382 -0
  7. dupes.pkl +3 -0
  8. infer.py +55 -0
  9. lib.py +19 -0
  10. utils.py +45 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ *.pyc
app.py CHANGED
@@ -1,9 +1,10 @@
  import gradio as gr
+ from infer import infer

  def greet(name):
      return f'Hello {name}!!'

- demo = gr.Interface(fn=greet,
+ demo = gr.Interface(fn=infer,
                      inputs=gr.Audio(sources='upload', type='filepath'),
                      outputs='text')
  demo.launch()
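
app.py now routes the uploaded file path from Gradio straight into infer from infer.py. The same entry point can be exercised without the web UI; the mp3 path below is only a placeholder:

from infer import infer

infer('some_track.mp3')  # placeholder path; prints the generated caption for this audio file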
audiocaptioner.py ADDED
@@ -0,0 +1,68 @@
+ from lib import *
+
+ import contextlib
+ import io
+ import laion_clap
+ import torch
+
+ class AudioCaptioner(torch.nn.Module):
+
+     def get_dummy_token(self, batch_size: int) -> torch.Tensor:
+         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64)
+
+     def embed_waveform(self, waveform):
+         # compute the prefix
+         input_dict = {
+             'waveform': waveform # you can add more key-values
+         }
+         audio_embeds = self.clap_model.model.encode_audio(
+             input_dict,
+             device=waveform.device
+         )
+
+         # get BxD-dim embedding (last layer) D = 1024 -> 512 after audio projection
+         audio_embedding = torch.nn.functional.normalize(self.clap_model.model.audio_projection(audio_embeds['embedding']), dim=-1)
+         return audio_embedding
+
+     def create_prefix(self, waveform, batch_size):
+         if waveform is not None:
+             audio_embedding = self.embed_waveform(waveform)
+         else:
+             audio_embedding = torch.zeros(batch_size, self.prefix_size).cuda()
+         # project the prefix through map net and append it
+         prefix_projections = self.clip_project(audio_embedding).view(-1, self.prefix_length, self.gpt_embedding_size)
+         return prefix_projections
+
+     def forward(self, tokens: torch.Tensor, waveform: torch.Tensor, mask: Optional[torch.Tensor] = None,
+                 labels: Optional[torch.Tensor] = None, freeze_gpt = False):
+         # embed the text
+         embedding_text = self.gpt.transformer.wte(tokens)
+         prefix_projections = self.create_prefix(waveform, tokens.shape[0])
+         embedding_text = torch.cat((prefix_projections, embedding_text), dim=1)
+         # offset labels
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         # push through GPT
+         if freeze_gpt:
+             with torch.no_grad():
+                 out = self.gpt(inputs_embeds=embedding_text, labels=labels, attention_mask=mask)
+         else:
+             out = self.gpt(inputs_embeds=embedding_text, labels=labels, attention_mask=mask)
+         return out
+
+     def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
+                  num_layers: int = 8):
+         super(AudioCaptioner, self).__init__()
+         self.prefix_size = prefix_size
+         self.prefix_length = prefix_length
+         self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
+                                  self.gpt_embedding_size * prefix_length))
+         self.clap_model = laion_clap.CLAP_Module(
+             enable_fusion=False,
+             amodel = 'HTSAT-base'
+         )
+         with contextlib.redirect_stdout(io.StringIO()):
+             self.clap_model.load_ckpt(ckpt = '/graft1/datasets/kechen/clap_ckpt/music_audioset_epoch_15_esc_90.14.pt')
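
For orientation, a minimal sketch of how this module is driven (the real entry point is infer.py below): CLAP encodes the waveform, the MLP mapper projects the 512-dim audio embedding into prefix_length GPT-2 token embeddings, and that prefix conditions generation. This assumes laion_clap is installed and the hard-coded CLAP checkpoint path above exists; the waveform shape mirrors the 10-second clips produced by data_module.py.

import torch
from audiocaptioner import AudioCaptioner

model = AudioCaptioner(prefix_length=10, prefix_size=512).eval()  # loads GPT-2 and the CLAP checkpoint
waveform = torch.zeros(1, 441000)  # one 10-second mono clip (assumed shape, as in data_module.py)
with torch.no_grad():
    prefix = model.create_prefix(waveform, batch_size=1)
print(prefix.shape)  # (1, 10, 768): ten GPT-2 embedding vectors used as the caption prefix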
audiostock-train-240k.txt ADDED
The diff for this file is too large to render. See raw diff
 
clipcap.py ADDED
@@ -0,0 +1,405 @@
+ #####################################################################
+ ### Credit: Ron Mokady / rmokady ###
+ ### Original Repo: https://github.com/rmokady/CLIP_prefix_caption ###
+ #####################################################################
+
+ from enum import Enum
+ from collections import defaultdict
+ import os
+ from torch import nn
+ import numpy as np
+ import torch
+ import torch.nn.functional as nnf
+ import sys
+ from typing import Tuple, List, Union, Optional
+ from transformers import (
+     GPT2Tokenizer,
+     GPT2LMHeadModel,
+     AdamW,
+     get_linear_schedule_with_warmup,
+ )
+
+ # import torch
+
+ N = type(None)
+ V = np.array
+ ARRAY = np.ndarray
+ ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
+ VS = Union[Tuple[V, ...], List[V]]
+ VN = Union[V, N]
+ VNS = Union[VS, N]
+ T = torch.Tensor
+ TS = Union[Tuple[T, ...], List[T]]
+ TN = Optional[T]
+ TNS = Union[Tuple[TN, ...], List[TN]]
+ TSN = Optional[TS]
+ TA = Union[T, ARRAY]
+
+ WEIGHTS_PATHS = {
+     "coco": "coco_weights.pt",
+     "conceptual-captions": "conceptual_weights.pt",
+ }
+
+ class MappingType(Enum):
+     MLP = 'mlp'
+     Transformer = 'transformer'
+
+ class MLP(nn.Module):
+     def forward(self, x: T) -> T:
+         return self.model(x)
+
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+         self.model = nn.Sequential(*layers)
+
+ class MlpTransformer(nn.Module):
+     def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
+         super().__init__()
+         out_d = out_d if out_d is not None else in_dim
+         self.fc1 = nn.Linear(in_dim, h_dim)
+         self.act = act
+         self.fc2 = nn.Linear(h_dim, out_d)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.dropout(x)
+         x = self.fc2(x)
+         x = self.dropout(x)
+         return x
+
+ class MultiHeadAttention(nn.Module):
+
+     def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim_self // num_heads
+         self.scale = head_dim ** -0.5
+         self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
+         self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
+         self.project = nn.Linear(dim_self, dim_self)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, y=None, mask=None):
+         y = y if y is not None else x
+         b, n, c = x.shape
+         _, m, d = y.shape
+         # b n h dh
+         queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
+         # b m 2 h dh
+         keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
+         keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
+         attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
+         if mask is not None:
+             if mask.dim() == 2:
+                 mask = mask.unsqueeze(1)
+             attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
+         attention = attention.softmax(dim=2)
+         out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
+         out = self.project(out)
+         return out, attention
+
+
+ class TransformerLayer(nn.Module):
+
+     def forward_with_attention(self, x, y=None, mask=None):
+         x_, attention = self.attn(self.norm1(x), y, mask)
+         x = x + x_
+         x = x + self.mlp(self.norm2(x))
+         return x, attention
+
+     def forward(self, x, y=None, mask=None):
+         x = x + self.attn(self.norm1(x), y, mask)[0]
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+     def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
+                  norm_layer: nn.Module = nn.LayerNorm):
+         super().__init__()
+         self.norm1 = norm_layer(dim_self)
+         self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
+         self.norm2 = norm_layer(dim_self)
+         self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
+
+
+ class Transformer(nn.Module):
+
+     def forward_with_attention(self, x, y=None, mask=None):
+         attentions = []
+         for layer in self.layers:
+             x, att = layer.forward_with_attention(x, y, mask)
+             attentions.append(att)
+         return x, attentions
+
+     def forward(self, x, y=None, mask=None):
+         for i, layer in enumerate(self.layers):
+             if i % 2 == 0 and self.enc_dec: # cross
+                 x = layer(x, y)
+             elif self.enc_dec: # self
+                 x = layer(x, x, mask)
+             else: # self or cross
+                 x = layer(x, y, mask)
+         return x
+
+     def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
+                  mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
+         super(Transformer, self).__init__()
+         dim_ref = dim_ref if dim_ref is not None else dim_self
+         self.enc_dec = enc_dec
+         if enc_dec:
+             num_layers = num_layers * 2
+         layers = []
+         for i in range(num_layers):
+             if i % 2 == 0 and enc_dec: # cross
+                 layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+             elif enc_dec: # self
+                 layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+             else: # self or cross
+                 layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+         self.layers = nn.ModuleList(layers)
+
+
+ class TransformerMapper(nn.Module):
+
+     def forward(self, x):
+         x = self.linear(x).view(x.shape[0], self.clip_length, -1)
+         prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
+         prefix = torch.cat((x, prefix), dim=1)
+         out = self.transformer(prefix)[:, self.clip_length:]
+         return out
+
+     def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
+         super(TransformerMapper, self).__init__()
+         self.clip_length = clip_length
+         self.transformer = Transformer(dim_embedding, 8, num_layers)
+         self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
+         self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
+
+
+ class ClipCaptionModel(nn.Module):
+
+     def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
+         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+     def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
+                 labels: Optional[torch.Tensor] = None):
+         embedding_text = self.gpt.transformer.wte(tokens)
+         if prefix is not None:
+             prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+             embedding_text = torch.cat((prefix_projections, embedding_text), dim=1)
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         out = self.gpt(inputs_embeds=embedding_text, labels=labels, attention_mask=mask)
+         return out
+
+     def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
+                  num_layers: int = 8, mapping_type: MappingType = MappingType.MLP):
+         super(ClipCaptionModel, self).__init__()
+         self.prefix_size = prefix_size
+         self.prefix_length = prefix_length
+         self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         if mapping_type == MappingType.MLP:
+             self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
+                                      self.gpt_embedding_size * prefix_length))
+         else:
+             self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
+                                                   clip_length, num_layers)
+ class ClipCaptionPrefix(ClipCaptionModel):
+     def parameters(self, recurse: bool = True):
+         return self.clip_project.parameters()
+
+     def train(self, mode: bool = True):
+         super(ClipCaptionPrefix, self).train(mode)
+         self.gpt.eval()
+         return self
+
+
+ def generate_beam(
+     model,
+     tokenizer,
+     beam_size: int = 5,
+     prompt=None,
+     embed=None,
+     #entry_length=67,
+     entry_length=150,
+     #temperature=1.0,
+     temperature=0.7,
+     stop_token: str = ".",
+     no_repeat_ngram = 3,
+     #no_repeat_ngram = None,
+ ):
+
+     model.eval()
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     tokens = None
+     scores = None
+     device = next(model.parameters()).device
+     seq_lengths = torch.ones(beam_size, device=device)
+     is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+     filter_value = -float("Inf")
+     with torch.no_grad():
+         if embed is not None:
+             generated = embed
+         else:
+             if tokens is None:
+                 tokens = torch.tensor(tokenizer.encode(prompt))
+                 tokens = tokens.unsqueeze(0).to(device)
+                 generated = model.gpt.transformer.wte(tokens)
+
+         stop_seq = tokenizer.encode('<STOP>')
+
+         for i in range(entry_length):
+             outputs = model.gpt(inputs_embeds=generated)
+             logits = outputs.logits
+             logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+             logits = logits.softmax(-1).log()
+             # prevent repeated ngrams
+             if no_repeat_ngram is not None:
+                 if tokens is not None:
+                     for b in range(beam_size):
+                         tokens_list = tokens[b].tolist()
+                         for idx in range(len(tokens_list) - no_repeat_ngram):
+                             subseq = tokens_list[idx:idx+no_repeat_ngram]
+                             if tokens_list[-no_repeat_ngram+1:] == subseq[:-1] and subseq[-1] not in stop_seq:
+                                 logits[b, subseq[-1]] = filter_value
+             if scores is None:
+                 scores, next_tokens = logits.topk(beam_size, -1)
+                 generated = generated.expand(beam_size, *generated.shape[1:])
+                 next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+                 if tokens is None:
+                     tokens = next_tokens
+                 else:
+                     tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                     tokens = torch.cat((tokens, next_tokens), dim=1)
+             else:
+                 logits[is_stopped] = -float(np.inf)
+                 logits[is_stopped, 0] = 0
+                 scores_sum = scores[:, None] + logits
+                 seq_lengths[~is_stopped] += 1
+                 scores_sum_average = scores_sum / seq_lengths[:, None]
+                 scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
+                     beam_size, -1
+                 )
+                 next_tokens_source = next_tokens // scores_sum.shape[1]
+                 seq_lengths = seq_lengths[next_tokens_source]
+                 next_tokens = next_tokens % scores_sum.shape[1]
+                 next_tokens = next_tokens.unsqueeze(1)
+                 tokens = tokens[next_tokens_source]
+                 tokens = torch.cat((tokens, next_tokens), dim=1)
+                 generated = generated[next_tokens_source]
+                 scores = scores_sum_average * seq_lengths
+                 is_stopped = is_stopped[next_tokens_source]
+             next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(
+                 generated.shape[0], 1, -1
+             )
+             generated = torch.cat((generated, next_token_embed), dim=1)
+             is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+             if is_stopped.all():
+                 break
+     scores = scores / seq_lengths
+     output_list = tokens.cpu().numpy()
+     output_texts = [
+         tokenizer.decode(output[: int(length)])
+         for output, length in zip(output_list, seq_lengths)
+     ]
+     order = scores.argsort(descending=True)
+     output_texts = [output_texts[i] for i in order]
+     return output_texts
+
+
+ def generate2(
+     model,
+     tokenizer,
+     tokens=None,
+     prompt=None,
+     embed=None,
+     entry_count=1,
+     #entry_length=67, # maximum number of words
+     entry_length=150, # maximum number of words
+     top_p=0.8,
+     nucleus=False,
+     #temperature=1.0,
+     temperature=0.7,
+     stop_token: str = ".",
+     no_repeat_ngram = 3,
+ ):
+     model.eval()
+     generated_num = 0
+     generated_list = []
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     filter_value = -1e10
+     device = next(model.parameters()).device
+
+     with torch.no_grad():
+
+         for entry_idx in range(entry_count):
+             if embed is not None:
+                 generated = embed
+             else:
+                 if tokens is None:
+                     tokens = torch.tensor(tokenizer.encode(prompt))
+                     tokens = tokens.unsqueeze(0).to(device)
+
+                 generated = model.gpt.transformer.wte(tokens)
+
+             ngrams = defaultdict(lambda: set())
+             stop_seq = tokenizer.encode('<STOP>')
+
+             for i in range(entry_length):
+
+                 outputs = model.gpt(inputs_embeds=generated)
+                 logits = outputs.logits
+                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(
+                     nnf.softmax(sorted_logits, dim=-1), dim=-1
+                 )
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                     ..., :-1
+                 ].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = filter_value
+                 # remove any potential ngram repeats, unless part of <STOP>
+                 if no_repeat_ngram is not None:
+                     if tokens is not None:
+                         for token in ngrams[tuple(tokens[0][-no_repeat_ngram+1:].tolist())]:
+                             if token not in stop_seq:
+                                 logits[:, token] = filter_value
+                 # either sample or argmax
+                 if nucleus:
+                     distr = torch.distributions.categorical.Categorical(logits=logits.squeeze())
+                     next_token = distr.sample().unsqueeze(0).unsqueeze(0)
+                 else:
+                     next_token = torch.argmax(logits, -1).unsqueeze(0)
+                 next_token_embed = model.gpt.transformer.wte(next_token)
+                 if logits[:, next_token].item() == filter_value:
+                     break
+                 # add to our set of ngrams
+                 if no_repeat_ngram is not None:
+                     if tokens is not None and len(tokens[0]) >= no_repeat_ngram - 1:
+                         ngrams[tuple(tokens[0][-no_repeat_ngram+1:].tolist())].add(next_token.item())
+                 if tokens is None:
+                     tokens = next_token
+                 else:
+                     tokens = torch.cat((tokens, next_token), dim=1)
+                 generated = torch.cat((generated, next_token_embed), dim=1)
+                 if stop_token_index == next_token.item():
+                     break
+
+
+             output_list = tokens.cpu().tolist()[0]
+             output_text = tokenizer.decode(output_list)
+             generated_list.append(output_text)
+
+     return generated_list[0]
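
The decoding helpers at the bottom (generate_beam and generate2) turn a prefix embedding back into text; infer.py below uses the beam search variant. A hedged usage sketch, assuming model is an already-loaded AudioCaptioner (or ClipCaptionModel) and prefix_embed is a (1, prefix_length, 768) tensor produced by its mapping network:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# `model` and `prefix_embed` are assumed to exist; see infer.py for how they are built
candidates = generate_beam(model, tokenizer, embed=prefix_embed, beam_size=5)
caption = candidates[0]  # beams come back sorted by score, best first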
data_module.py ADDED
@@ -0,0 +1,382 @@
+ '''
+ Ke Chen | [email protected] & Nikita Srivatsan | [email protected]
+ Load the mp3 format data from audiostock-full dataset
+ '''
+ import json
+ import numpy as np
+ import os
+ import pandas as pd
+ from pathlib import PurePosixPath
+ import random
+ import torch
+ import torchaudio
+ from torch.utils.data import Dataset
+ import sys
+
+ from lib import *
+ from utils import *
+
+ import torch.utils.data
+
+ def int16_to_float32(x):
+     return (x / 32767.0).type(torch.float)
+
+
+ def float32_to_int16(x):
+     x = torch.clip(x, min=-1., max=1.)
+     return (x * 32767.).type(torch.int16)
+
+ def my_collate(batch):
+     batch = [x for x in batch if x is not None]
+     if len(batch) == 0:
+         return batch
+     else:
+         return torch.utils.data.dataloader.default_collate(batch)
+
+ class AudiostockDataset(Dataset):
+     '''
+     Args:
+         dataset_path (str): the dataset folder path
+         train (bool): if True, we randomly return a 10-sec chunk from each audio file; if False, we return the middle 10-sec chunk (fixed)
+         split (str): a txt file to assign the idx in this dataset (for training, validation and testing)
+         factor (float): how many times we need to loop the whole dataset, this is to increase the number of training data batches in each epoch
+         whole_track (bool): if True, the dataset will return the full length of the audio file. However, this means the batch_size = 1, and it is usually in the test/validation case
+     '''
+     def __init__(self, dataset_path, tweet_prefix=True, prefix_length=10, normalize=False, dupefile='dupes.pkl', train = True, split = None, factor = 1.0, whole_track = False, verbose=True, dedup=True, file_list=[]):
+         super().__init__()
+         # set up parameters
+         self.max_seq_len = 150
+         self.tweet_prefix = tweet_prefix
+         if self.tweet_prefix:
+             self.max_seq_len *= 2
+         self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=True)
+         self.prefix_length = prefix_length
+         self.normalize = normalize
+         self.id2neighbor = defaultdict(lambda: '')
+
+         if dedup:
+             if dupefile is not None and os.path.exists(dupefile):
+                 with open(dupefile, 'rb') as dupefile:
+                     self.is_rep = pickle.load(dupefile).is_rep
+             elif dupefile == 'both':
+                 with open('dupes.pkl', 'rb') as dupefile:
+                     dupes1 = pickle.load(dupefile)
+                 with open('dupes_audio.pkl', 'rb') as dupefile:
+                     dupes2 = pickle.load(dupefile)
+                 self.is_rep = defaultdict(lambda: True)
+                 for k,v in dupes1.is_rep.items():
+                     self.is_rep[k] = v
+                 for k,v in dupes2.is_rep.items():
+                     self.is_rep[k] = v
+             else:
+                 sys.exit('Could not find duplicate file')
+
+         subfolders = [f'audiostock-part-{i}' for i in range(1,9)]
+         self.label_path = os.path.join(dataset_path, 'audiostock-full-label')
+         self.whole_track = whole_track
+         self.file_list = file_list
+
+         # select out the elements for this split
+         if self.file_list == []:
+             temp_file_list = []
+             for subfolder in subfolders:
+                 temp_file_list += [os.path.join(dataset_path, subfolder, f) for f in os.listdir(os.path.join(dataset_path, subfolder)) if not dedup or self.is_rep[os.path.basename(f).split('.')[0]]]
+             if split is not None:
+                 split = set(np.loadtxt(split, dtype = str))
+                 self.file_list = [f for f in temp_file_list if os.path.basename(f).split('.')[0] in split]
+             else:
+                 self.file_list = temp_file_list
+
+         self.train = train
+         self.total_len = int(len(self.file_list) * factor)
+         if verbose:
+             print(f'Dataset Loaded | File Num.: {len(self.file_list)} | Batches per epoch: {self.total_len}')
+
+     def precompute_rand(self, candidate_set=None):
+         self.id2neighbor = defaultdict(lambda: '')
+         # if train
+         if candidate_set is None:
+             my_ids = []
+             candidate_caps = []
+             temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
+             for batch in temp_loader:
+                 my_ids += batch['id']
+                 candidate_caps += batch['short_text']
+             for idx in my_ids:
+                 self.id2neighbor[idx] = random.choice(candidate_caps)
+         # if test
+         else:
+             temp_loader = DataLoader(candidate_set, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
+             candidate_caps = []
+             for batch in temp_loader:
+                 candidate_caps += batch['short_text']
+             temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
+             my_ids = []
+             for batch in temp_loader:
+                 my_ids += batch['id']
+             for idx in my_ids:
+                 self.id2neighbor[idx] = random.choice(candidate_caps)
+
+     def precompute_gold(self):
+         self.id2neighbor = defaultdict(lambda: '')
+         temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
+         for batch in temp_loader:
+             for idx,short_text in zip(batch['id'], batch['short_text']):
+                 self.id2neighbor[idx] = short_text
+
+     def precompute_blank(self):
+         self.id2neighbor = defaultdict(lambda: '\n')
+
+     def precompute_neighbors(self, model, candidate_set=None):
+         print('Precomputing neighbors')
+         self.id2neighbor = defaultdict(lambda: '')
+         # if train and model given
+         if candidate_set is None:
+             # compute waveform embeddings for each song
+             cand_features = None
+             cand_ids = []
+             cand_caps = []
+             temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
+             progress = tqdm(total=len(temp_loader), dynamic_ncols=True)
+             for batch in temp_loader:
+                 with torch.no_grad():
+                     batch_features = model.embed_waveform(batch['waveform'].cuda())
+                 if cand_features is not None:
+                     cand_features = torch.cat([cand_features, batch_features])
+                 else:
+                     cand_features = batch_features
+                 cand_ids += batch['id']
+                 cand_caps += batch['short_text']
+                 progress.update()
+             progress.close()
+             my_features = cand_features
+             my_ids = cand_ids
+         # if test and model given
+         else:
+             # check if we already precomputed the embeddings
+             pickle_filename = 'nn_features.pkl'
+             if os.path.isfile(pickle_filename):
+                 with open(pickle_filename, 'rb') as f:
+                     (cand_features, cand_ids, cand_caps) = pickle.load(f)
+             else:
+                 # build the features from the provided set instead of self
+                 cand_features = None
+                 cand_ids = []
+                 cand_caps = []
+                 temp_loader = DataLoader(candidate_set, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
+                 progress = tqdm(total=len(temp_loader), dynamic_ncols=True)
+                 for batch in temp_loader:
+                     with torch.no_grad():
+                         batch_features = model.embed_waveform(batch['waveform'].cuda())
+                     if cand_features is not None:
+                         cand_features = torch.cat([cand_features, batch_features])
+                     else:
+                         cand_features = batch_features
+                     cand_ids += batch['id']
+                     #cand_caps += [' '.join(x.split()[:10]) for x in batch['short_text']]
+                     cand_caps += batch['short_text']
+                     progress.update()
+                 progress.close()
+                 # dump to pickle so we don't have to redo this each time
+                 with open(pickle_filename, 'wb') as f:
+                     pickle.dump((cand_features, cand_ids, cand_caps), f)
+             # load up my own ids and features
+             my_features = None
+             my_ids = []
+             temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
+             progress = tqdm(total=len(temp_loader), dynamic_ncols=True)
+             for batch in temp_loader:
+                 with torch.no_grad():
+                     batch_features = model.embed_waveform(batch['waveform'].cuda())
+                 if my_features is not None:
+                     my_features = torch.cat([my_features, batch_features])
+                 else:
+                     my_features = batch_features
+                 my_ids += batch['id']
+                 progress.update()
+             progress.close()
+         is_self_sim = my_ids == cand_ids
+         for idx,audio_id in tqdm(enumerate(my_ids), total=len(my_ids), dynamic_ncols=True):
+             features = my_features[idx]
+             similarities = features @ cand_features.T
+             # remove identical matches
+             if is_self_sim:
+                 similarities[idx] = float('-inf')
+             best_idx = torch.argmax(similarities)
+             most_similar_caption = cand_caps[best_idx]
+             self.id2neighbor[my_ids[idx]] = most_similar_caption
+
+     def pad_tokens(self, tokens, tokens_tweet):
+         tweet_text_len = 0
+         if self.tweet_prefix:
+             tweet_text_len = tokens_tweet[:self.max_seq_len // 2].shape[0]
+             tokens = torch.cat((tokens_tweet[:tweet_text_len], tokens))
+         padding = self.max_seq_len - tokens.shape[0]
+         if padding > 0:
+             tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
+         elif padding < 0:
+             tokens = tokens[:self.max_seq_len]
+         mask = tokens.ge(0) # mask is zero where we are out of sequence
+         tokens[~mask] = 0
+         mask = mask.float()
+         mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0) # adding prefix mask
+         return tokens, mask, tweet_text_len
+
+     def read_wav(self, filename):
+         stem = PurePosixPath(filename).stem
+         picklefile = f'wt-{self.whole_track}-t-{self.train}-{stem}.pt'
+         picklepath = f'/trunk/datasets/nsrivats/audiostock_proc/{picklefile}'
+         if os.path.exists(picklepath):
+             y = torch.load(picklepath)
+         else:
+             # chunk
+             try:
+                 num_frames = torchaudio.info(filename).num_frames
+             except:
+                 return None
+             # make sure it wasn't empty, if so die
+             if num_frames == 0:
+                 return None
+             sta = 0
+             if not self.whole_track:
+                 if self.train:
+                     sta = random.randint(0, num_frames - 441001)
+                 else:
+                     sta = (num_frames - 441001) // 2
+                 num_frames = 441000
+
+             y, sr = torchaudio.load(filename, frame_offset=sta, num_frames=num_frames)
+             # resample
+             y = torchaudio.functional.resample(y, sr, 48000)
+             y = y[:, :441000]
+             # mono
+             y = y.mean(dim=0)
+             # normalize
+             y = int16_to_float32(float32_to_int16(y))
+             # save
+             torch.save(y, picklepath)
+         return y
+
+     def __getitem__(self, index):
+         idx = index % len(self.file_list)
+         data_dict = {}
+         f = self.file_list[idx]
+         lf = os.path.join(self.label_path, os.path.basename(f).split('.')[0] + '.json')
+         data_dict['waveform'] = self.read_wav(f)
+         if os.path.isfile(lf):
+             with open(lf,'r') as label_file:
+                 label_data = json.load(label_file)
+             data_dict['id'] = label_data['id']
+             data_dict['short_text'] = label_data['short_text']
+             if self.normalize:
+                 data_dict['short_text'] = ' '.join(muscaps_tokenize(data_dict['short_text']))
+             if 'long_text' in label_data and label_data['long_text'] is not None:
+                 data_dict['long_text'] = label_data['long_text']
+             else:
+                 data_dict['long_text'] = ''
+             '''
+             data_dict['tag'] = label_data['tag']
+             data_dict['impression'] = label_data['impression']
+             data_dict['purpose'] = label_data['purpose']
+             '''
+         else:
+             data_dict['id'] = os.path.basename(f).split('.')[0]
+             data_dict['short_text'] = ''
+             data_dict['long_text'] = ''
+
+         # tokenize the caption
+         caption_proc = preproc(data_dict['short_text'], self.tokenizer)
+         tokens = torch.tensor(caption_proc, dtype=torch.int64)
+         tweet_text = self.id2neighbor[data_dict['id']] if self.tweet_prefix else ''
+         tweet_proc = preproc(tweet_text, self.tokenizer, stop=False)
+         tokens_tweet = torch.tensor(tweet_proc, dtype=torch.int64)
+         tokens, mask, tweet_text_len = self.pad_tokens(tokens, tokens_tweet)
+         data_dict['tokens'] = tokens
+         data_dict['mask'] = mask
+         data_dict['tweet_text_len'] = tweet_text_len
+         data_dict['tweet_text'] = tweet_text
+
+         if (data_dict['id'] is None or
+             data_dict['short_text'] is None or
+             data_dict['long_text'] is None or
+             data_dict['tokens'] is None or
+             data_dict['mask'] is None or
+             data_dict['tweet_text_len'] is None or
+             data_dict['tweet_text'] is None or
+             data_dict['waveform'] is None
+         ):
+             return None
+         else:
+             return data_dict
+
+     def __len__(self):
+         return self.total_len
+
+ class MusicCapsDataset(AudiostockDataset):
+     def __init__(self, dataset_path, args, train = True, split = None, factor = 1.0, whole_track = False, verbose=True, dedup=True):
+         super(AudiostockDataset, self).__init__()
+         # set up parameters
+         self.max_seq_len = 150
+         self.tweet_prefix = args.tweet_prefix
+         if self.tweet_prefix:
+             self.max_seq_len *= 2
+         self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=True)
+         self.prefix_length = args.prefix_length
+         self.normalize = args.normalize
+         self.whole_track = whole_track
+
+         self.label_path = os.path.join(dataset_path, 'audio')
+         self.file_list = []
+         self.label_data = []
+         label_reader = pd.read_csv(f'{dataset_path}/musiccaps-resplit.csv')
+         for idx,row in label_reader.iterrows():
+             if (row['is_audioset_eval'] == 1 and split == 'musiccaps_eval') \
+                     or (row['is_audioset_eval'] == 0 and split == 'musiccaps_train') \
+                     or (row['is_audioset_eval'] == 2 and split == 'musiccaps_dev'):
+                 data_dict = {}
+                 data_dict['id'] = row['ytid']
+                 self.file_list.append(f"{dataset_path}/audio/{data_dict['id']}.wav")
+                 data_dict['short_text'] = row['caption']
+                 if self.normalize:
+                     data_dict['short_text'] = ' '.join(muscaps_tokenize(data_dict['short_text']))
+                 data_dict['long_text'] = ''
+                 data_dict['tag'] = row['aspect_list']
+                 self.label_data.append(data_dict)
+
+         self.train = train
+         self.total_len = int(len(self.file_list) * factor)
+         if verbose:
+             print(f'Dataset Loaded | File Num.: {len(self.file_list)} | Batches per epoch: {self.total_len}')
+
+     def __getitem__(self, index):
+         idx = index % len(self.file_list)
+         data_dict = {}
+         f = self.file_list[idx]
+         data_dict['waveform'] = self.read_wav(f)
+         for k,v in self.label_data[idx].items():
+             data_dict[k] = v
+
+         # tokenize the caption
+         caption_proc = preproc(data_dict['short_text'], self.tokenizer)
+         tokens = torch.tensor(caption_proc, dtype=torch.int64)
+         tweet_text = self.id2neighbor[data_dict['id']] if self.tweet_prefix else ''
+         tweet_proc = preproc(tweet_text, self.tokenizer, stop=False)
+         tokens_tweet = torch.tensor(tweet_proc, dtype=torch.int64)
+         tokens, mask, tweet_text_len = self.pad_tokens(tokens, tokens_tweet)
+         data_dict['tokens'] = tokens
+         data_dict['mask'] = mask
+         data_dict['tweet_text_len'] = tweet_text_len
+         data_dict['tweet_text'] = tweet_text
+
+         if (data_dict['id'] is None or
+             data_dict['short_text'] is None or
+             data_dict['long_text'] is None or
+             data_dict['tokens'] is None or
+             data_dict['mask'] is None or
+             data_dict['tweet_text_len'] is None or
+             data_dict['tweet_text'] is None or
+             data_dict['waveform'] is None
+         ):
+             return None
+         else:
+             return data_dict
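
The class docstring above describes the constructor arguments; a condensed sketch of how the dataset is typically consumed (this mirrors the DataLoader settings of the precompute_* helpers and the paths used by infer.py, so it only runs where the Audiostock audio, the cached GPT-2 tokenizer, and dupes.pkl are actually on disk):

from torch.utils.data import DataLoader
from data_module import AudiostockDataset, my_collate

dataset = AudiostockDataset(
    dataset_path='',  # as in infer.py; labels are looked up under <dataset_path>/audiostock-full-label
    train=False,
    file_list=open('audiostock-train-240k.txt').read().split(),
)
loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=my_collate)
batch = next(iter(loader))  # dict with 'waveform', 'tokens', 'mask', 'id', 'short_text', ...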
dupes.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e83b71d63cd11dc8840b44bcea625d1c618c8b421e4c6ec6c65580af5109c7bd
+ size 1807022
infer.py ADDED
@@ -0,0 +1,55 @@
+ from audiocaptioner import AudioCaptioner
+ from data_module import AudiostockDataset
+ from utils import *
+
+ def infer(input_filename):
+     device = get_device(0)
+     # connect to GCS
+     gcs = CheckpointManager()
+     # create and/or load model
+     tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=False)
+     prefix_dim = 512
+     prefix_length = 10
+     prefix_length_clip = 10
+     num_layers = 8
+     checkpoint = 'checkpoints/ZRIUE-BEST.pt'
+     model = AudioCaptioner(prefix_length, clip_length=prefix_length_clip, prefix_size=prefix_dim, num_layers=num_layers).to(device)
+     model.load_state_dict(gcs.get_checkpoint(checkpoint))
+     print(f'Loaded from {checkpoint}')
+     model.eval()
+     # read in the wav file and precompute neighbors
+     #dataset_path = '/graft1/datasets/kechen/audiostock-full'
+     dataset_path = ''
+     train_dataset = AudiostockDataset(
+         dataset_path=dataset_path,
+         train=False,
+         split='audiostock-train-240k.txt',
+         factor=1.0,
+         verbose=False,
+         file_list=open('audiostock-train-240k.txt', 'r').read().split()
+     )
+     print('Reading in file', input_filename)
+     dataset = AudiostockDataset(
+         dataset_path=dataset_path,
+         train=False,
+         split=None,
+         factor=1.0,
+         verbose=False,
+         file_list=[input_filename] # manually override file list
+     )
+     dataset.precompute_neighbors(model, candidate_set=train_dataset)
+     waveform = dataset.read_wav(input_filename).unsqueeze(0).to(device, dtype=torch.float32)
+     # predict
+     with torch.no_grad():
+         prefix_embed = model.create_prefix(waveform, 1)
+         tweet_tokens = torch.tensor(preproc(dataset.id2neighbor[os.path.basename(input_filename).split('.')[0]], tokenizer, stop=False), dtype=torch.int64).to(device)[:150]
+         tweet_embed = model.gpt.transformer.wte(tweet_tokens)
+         prefix_embed = torch.cat([prefix_embed, tweet_embed.unsqueeze(0)], dim=1)
+         candidates = generate_beam(model, tokenizer, embed=prefix_embed, beam_size=5)
+         generated_text = candidates[0]
+         generated_text = postproc(generated_text)
+         print('=======================================')
+         print(generated_text)
+
+ if __name__ == '__main__':
+     infer('../MusicCaptioning/sample_inputs/sisters.mp3')
lib.py ADDED
@@ -0,0 +1,19 @@
+ from collections import defaultdict
+ import json
+ import numpy as np
+ import os
+ import pandas as pd
+ import dill as pickle
+ pickle._dill._reverse_typemap['ClassType'] = type
+ import random
+ import string
+ import sys
+ import torch
+ from torch import nn
+ import torch.nn.functional as nnf
+ from torch.utils.data import Dataset, DataLoader
+ from tqdm import tqdm
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+ from typing import Tuple, List, Union, Optional
+
+ from clipcap import *
utils.py ADDED
@@ -0,0 +1,45 @@
+ from lib import *
+ from twokenize import tokenizeRawTweetText
+ import re
+
+ def muscaps_tokenize(raw):
+     raw = raw.lower()
+     for punc in string.punctuation:
+         raw = raw.replace(punc, ' ')
+     tokens = raw.split()
+     return tokens
+
+ def get_device(device_id: int) -> torch.device:
+     if not torch.cuda.is_available():
+         return torch.device('cpu')
+     device_id = min(torch.cuda.device_count() - 1, device_id)
+     return torch.device(f'cuda:{device_id}')
+
+ def preproc(caption, tokenizer, stop=True):
+     caption = caption.replace('.', '<STOP>')
+     caption_proc = tokenizer.encode(caption)
+     if stop:
+         caption_proc += tokenizer.encode('.')
+     return caption_proc
+
+ def postproc(caption):
+     caption = caption.replace('<STOP>', '.')
+     if caption[-1] == '.':
+         caption = caption[:-1]
+     return caption
+
+ class CheckpointManager:
+     def __init__(self):
+         self.checkpoint_dir = '/home/nsrivats/Repositories/MusicCaptioning/checkpoints'
+
+     def get_checkpoint(self, checkpoint):
+         with open(checkpoint, 'rb') as infile:
+             return torch.load(infile)
+
+     def save_checkpoint(self, state_dict, checkpoint):
+         filename = f'{self.checkpoint_dir}/{checkpoint}'
+         with open(filename, 'wb') as outfile:
+             torch.save(state_dict, outfile)
+
+     def save_logs(self, logdir):
+         pass
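
preproc and postproc implement the sentence-boundary trick used throughout this commit: periods inside a caption are rewritten to a literal <STOP> marker before GPT-2 tokenization, apparently so that the '.' stop token only fires at the true end of a generation, and postproc undoes the substitution. A small worked example (the caption text is made up):

from transformers import GPT2Tokenizer
from utils import preproc, postproc

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
ids = preproc('Upbeat acoustic guitar. Good for ads.', tokenizer)  # '.' -> '<STOP>', one real '.' appended
decoded = tokenizer.decode(ids)  # 'Upbeat acoustic guitar<STOP> Good for ads<STOP>.'
print(postproc(decoded))  # '<STOP>' -> '.', trailing '.' stripped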