NikitaSrivatsan committed 48ac659
Parent(s): 3d5b800

First pass at captioning functionality through web app

Files changed:
- .gitignore +1 -0
- app.py +2 -1
- audiocaptioner.py +68 -0
- audiostock-train-240k.txt +0 -0
- clipcap.py +405 -0
- data_module.py +382 -0
- dupes.pkl +3 -0
- infer.py +55 -0
- lib.py +19 -0
- utils.py +45 -0
.gitignore
ADDED
@@ -0,0 +1 @@
*.pyc
app.py
CHANGED
@@ -1,9 +1,10 @@
 import gradio as gr
+from infer import infer

 def greet(name):
     return f'Hello {name}!!'

-demo = gr.Interface(fn=greet,
+demo = gr.Interface(fn=infer,
                     inputs=gr.Audio(sources='upload', type='filepath'),
                     outputs='text')
 demo.launch()
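For local testing of the Gradio wiring without the trained checkpoint, a stub can stand in for infer. This is a sketch only; fake_infer is a hypothetical placeholder, not part of the repository:

# hypothetical smoke test of the Gradio interface; fake_infer is a stand-in, not repo code
import gradio as gr

def fake_infer(filepath: str) -> str:
    # echo the uploaded file path instead of running the captioner
    return f'received {filepath}'

demo = gr.Interface(fn=fake_infer,
                    inputs=gr.Audio(sources='upload', type='filepath'),
                    outputs='text')
demo.launch()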
audiocaptioner.py
ADDED
@@ -0,0 +1,68 @@
from lib import *

import contextlib
import io
import laion_clap
import torch

class AudioCaptioner(torch.nn.Module):

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def embed_waveform(self, waveform):
        # compute the prefix
        input_dict = {
            'waveform': waveform  # you can add more key-values
        }
        audio_embeds = self.clap_model.model.encode_audio(
            input_dict,
            device=waveform.device
        )

        # get BxD-dim embedding (last layer) D = 1024 -> 512 after audio projection
        audio_embedding = torch.nn.functional.normalize(self.clap_model.model.audio_projection(audio_embeds['embedding']), dim=-1)
        return audio_embedding

    def create_prefix(self, waveform, batch_size):
        if waveform is not None:
            audio_embedding = self.embed_waveform(waveform)
        else:
            audio_embedding = torch.zeros(batch_size, self.prefix_size).cuda()
        # project the prefix through map net and append it
        prefix_projections = self.clip_project(audio_embedding).view(-1, self.prefix_length, self.gpt_embedding_size)
        return prefix_projections

    def forward(self, tokens: torch.Tensor, waveform: torch.Tensor, mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None, freeze_gpt = False):
        # embed the text
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.create_prefix(waveform, tokens.shape[0])
        embedding_text = torch.cat((prefix_projections, embedding_text), dim=1)
        # offset labels
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        # push through GPT
        if freeze_gpt:
            with torch.no_grad():
                out = self.gpt(inputs_embeds=embedding_text, labels=labels, attention_mask=mask)
        else:
            out = self.gpt(inputs_embeds=embedding_text, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
                 num_layers: int = 8):
        super(AudioCaptioner, self).__init__()
        self.prefix_size = prefix_size
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
                                 self.gpt_embedding_size * prefix_length))
        self.clap_model = laion_clap.CLAP_Module(
            enable_fusion=False,
            amodel = 'HTSAT-base'
        )
        with contextlib.redirect_stdout(io.StringIO()):
            self.clap_model.load_ckpt(ckpt = '/graft1/datasets/kechen/clap_ckpt/music_audioset_epoch_15_esc_90.14.pt')
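For orientation, a minimal sketch of exercising AudioCaptioner end to end. It assumes the GPT-2 weights and the CLAP checkpoint at the hard-coded path above load successfully and that a CUDA device is available; the tensor shapes mirror the roughly 10-second clips and padded token ids produced by data_module.py, but are otherwise illustrative:

# sketch only: requires the GPT-2 weights and the CLAP checkpoint referenced above
import torch
from audiocaptioner import AudioCaptioner

model = AudioCaptioner(prefix_length=10, clip_length=10, prefix_size=512).cuda().eval()

waveform = torch.zeros(1, 441000).cuda()                 # one ~10 s mono clip, as returned by read_wav
tokens = torch.zeros(1, 150, dtype=torch.int64).cuda()   # padded GPT-2 token ids
mask = torch.ones(1, 160).cuda()                         # prefix_length + token positions

with torch.no_grad():
    out = model(tokens, waveform, mask=mask)             # GPT-2 output over prefix + text positions
print(out.logits.shape)                                  # (1, 160, vocab_size)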
audiostock-train-240k.txt
ADDED
The diff for this file is too large to render.
clipcap.py
ADDED
@@ -0,0 +1,405 @@
#####################################################################
### Credit: Ron Mokady / rmokady ###
### Original Repo: https://github.com/rmokady/CLIP_prefix_caption ###
#####################################################################

from enum import Enum
from collections import defaultdict
import os
from torch import nn
import numpy as np
import torch
import torch.nn.functional as nnf
import sys
from typing import Tuple, List, Union, Optional
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    AdamW,
    get_linear_schedule_with_warmup,
)

# import torch

N = type(None)
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]

WEIGHTS_PATHS = {
    "coco": "coco_weights.pt",
    "conceptual-captions": "conceptual_weights.pt",
}

class MappingType(Enum):
    MLP = 'mlp'
    Transformer = 'transformer'

class MLP(nn.Module):
    def forward(self, x: T) -> T:
        return self.model(x)

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)

class MlpTransformer(nn.Module):
    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
        super().__init__()
        out_d = out_d if out_d is not None else in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class MultiHeadAttention(nn.Module):

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, d = y.shape
        # b n h dh
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
        out = self.project(out)
        return out, attention


class TransformerLayer(nn.Module):

    def forward_with_attention(self, x, y=None, mask=None):
        x_, attention = self.attn(self.norm1(x), y, mask)
        x = x + x_
        x = x + self.mlp(self.norm2(x))
        return x, attention

    def forward(self, x, y=None, mask=None):
        x = x + self.attn(self.norm1(x), y, mask)[0]
        x = x + self.mlp(self.norm2(x))
        return x

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)


class Transformer(nn.Module):

    def forward_with_attention(self, x, y=None, mask=None):
        attentions = []
        for layer in self.layers:
            x, att = layer.forward_with_attention(x, y, mask)
            attentions.append(att)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        for i, layer in enumerate(self.layers):
            if i % 2 == 0 and self.enc_dec:  # cross
                x = layer(x, y)
            elif self.enc_dec:  # self
                x = layer(x, x, mask)
            else:  # self or cross
                x = layer(x, y, mask)
        return x

    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super(Transformer, self).__init__()
        dim_ref = dim_ref if dim_ref is not None else dim_self
        self.enc_dec = enc_dec
        if enc_dec:
            num_layers = num_layers * 2
        layers = []
        for i in range(num_layers):
            if i % 2 == 0 and enc_dec:  # cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            elif enc_dec:  # self
                layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            else:  # self or cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(layers)


class TransformerMapper(nn.Module):

    def forward(self, x):
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        prefix = torch.cat((x, prefix), dim=1)
        out = self.transformer(prefix)[:, self.clip_length:]
        return out

    def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)


class ClipCaptionModel(nn.Module):

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None):
        embedding_text = self.gpt.transformer.wte(tokens)
        if prefix is not None:
            prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
            embedding_text = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_text, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
                 num_layers: int = 8, mapping_type: MappingType = MappingType.MLP):
        super(ClipCaptionModel, self).__init__()
        self.prefix_size = prefix_size
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if mapping_type == MappingType.MLP:
            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
                                     self.gpt_embedding_size * prefix_length))
        else:
            self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
                                                  clip_length, num_layers)

class ClipCaptionPrefix(ClipCaptionModel):
    def parameters(self, recurse: bool = True):
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super(ClipCaptionPrefix, self).train(mode)
        self.gpt.eval()
        return self


def generate_beam(
    model,
    tokenizer,
    beam_size: int = 5,
    prompt=None,
    embed=None,
    #entry_length=67,
    entry_length=150,
    #temperature=1.0,
    temperature=0.7,
    stop_token: str = ".",
    no_repeat_ngram = 3,
    #no_repeat_ngram = None,
):

    model.eval()
    stop_token_index = tokenizer.encode(stop_token)[0]
    tokens = None
    scores = None
    device = next(model.parameters()).device
    seq_lengths = torch.ones(beam_size, device=device)
    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
    filter_value = -float("Inf")
    with torch.no_grad():
        if embed is not None:
            generated = embed
        else:
            if tokens is None:
                tokens = torch.tensor(tokenizer.encode(prompt))
                tokens = tokens.unsqueeze(0).to(device)
                generated = model.gpt.transformer.wte(tokens)

        stop_seq = tokenizer.encode('<STOP>')

        for i in range(entry_length):
            outputs = model.gpt(inputs_embeds=generated)
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = logits.softmax(-1).log()
            # prevent repeated ngrams
            if no_repeat_ngram is not None:
                if tokens is not None:
                    for b in range(beam_size):
                        tokens_list = tokens[b].tolist()
                        for idx in range(len(tokens_list) - no_repeat_ngram):
                            subseq = tokens_list[idx:idx+no_repeat_ngram]
                            if tokens_list[-no_repeat_ngram+1:] == subseq[:-1] and subseq[-1] not in stop_seq:
                                logits[b, subseq[-1]] = filter_value
            if scores is None:
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
                    beam_size, -1
                )
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]
            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(
                generated.shape[0], 1, -1
            )
            generated = torch.cat((generated, next_token_embed), dim=1)
            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break
    scores = scores / seq_lengths
    output_list = tokens.cpu().numpy()
    output_texts = [
        tokenizer.decode(output[: int(length)])
        for output, length in zip(output_list, seq_lengths)
    ]
    order = scores.argsort(descending=True)
    output_texts = [output_texts[i] for i in order]
    return output_texts


def generate2(
    model,
    tokenizer,
    tokens=None,
    prompt=None,
    embed=None,
    entry_count=1,
    #entry_length=67,  # maximum number of words
    entry_length=150,  # maximum number of words
    top_p=0.8,
    nucleus=False,
    #temperature=1.0,
    temperature=0.7,
    stop_token: str = ".",
    no_repeat_ngram = 3,
):
    model.eval()
    generated_num = 0
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -1e10
    device = next(model.parameters()).device

    with torch.no_grad():

        for entry_idx in range(entry_count):
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)

                generated = model.gpt.transformer.wte(tokens)

            ngrams = defaultdict(lambda: set())
            stop_seq = tokenizer.encode('<STOP>')

            for i in range(entry_length):

                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    nnf.softmax(sorted_logits, dim=-1), dim=-1
                )
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value
                # remove any potential ngram repeats, unless part of <STOP>
                if no_repeat_ngram is not None:
                    if tokens is not None:
                        for token in ngrams[tuple(tokens[0][-no_repeat_ngram+1:].tolist())]:
                            if token not in stop_seq:
                                logits[:, token] = filter_value
                # either sample or argmax
                if nucleus:
                    distr = torch.distributions.categorical.Categorical(logits=logits.squeeze())
                    next_token = distr.sample().unsqueeze(0).unsqueeze(0)
                else:
                    next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if logits[:, next_token].item() == filter_value:
                    break
                # add to our set of ngrams
                if no_repeat_ngram is not None:
                    if tokens is not None and len(tokens[0]) >= no_repeat_ngram - 1:
                        ngrams[tuple(tokens[0][-no_repeat_ngram+1:].tolist())].add(next_token.item())
                if tokens is None:
                    tokens = next_token
                else:
                    tokens = torch.cat((tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            output_list = tokens.cpu().tolist()[0]
            output_text = tokenizer.decode(output_list)
            generated_list.append(output_text)

    return generated_list[0]
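As a rough illustration of how the decoding helpers are driven, the sketch below feeds generate_beam a random prefix embedding. _GPTOnly is a hypothetical stand-in that only exposes the .gpt attribute generate_beam expects; in the real pipeline the prefix comes from AudioCaptioner.create_prefix:

# illustration only; _GPTOnly is a stand-in, not part of the repository
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from clipcap import generate_beam

class _GPTOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = _GPTOnly().eval()

# an untrained random prefix of length 10 in GPT-2's embedding space
prefix_embed = torch.randn(1, 10, model.gpt.transformer.wte.weight.shape[1])
with torch.no_grad():
    captions = generate_beam(model, tokenizer, embed=prefix_embed, beam_size=5)
print(captions[0])  # beams are returned best-first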
data_module.py
ADDED
@@ -0,0 +1,382 @@
'''
Ke Chen | [email protected] & Nikita Srivatsan | [email protected]
Load the mp3 format data from audiostock-full dataset
'''
import json
import numpy as np
import os
import pandas as pd
from pathlib import PurePosixPath
import random
import torch
import torchaudio
from torch.utils.data import Dataset
import sys

from lib import *
from utils import *

import torch.utils.data

def int16_to_float32(x):
    return (x / 32767.0).type(torch.float)


def float32_to_int16(x):
    x = torch.clip(x, min=-1., max=1.)
    return (x * 32767.).type(torch.int16)

def my_collate(batch):
    batch = [x for x in batch if x is not None]
    if len(batch) == 0:
        return batch
    else:
        return torch.utils.data.dataloader.default_collate(batch)

class AudiostockDataset(Dataset):
    '''
    Args:
        dataset_path (str): the dataset folder path
        train (bool): if True, we randomly return a 10-sec chunk from each audio file; if False, we return the middle 10-sec chunk (fixed)
        split (str): a txt file to assign the idx in this dataset (for training, validation and testing)
        factor (float): how many times we need to loop the whole dataset, this is to increase the number of training data batches in each epoch
        whole_track (bool): if True, the dataset will return the full length of the audio file. However, this means the batch_size = 1, and it is usually in the test/validation case
    '''
    def __init__(self, dataset_path, tweet_prefix=True, prefix_length=10, normalize=False, dupefile='dupes.pkl', train = True, split = None, factor = 1.0, whole_track = False, verbose=True, dedup=True, file_list=[]):
        super().__init__()
        # set up parameters
        self.max_seq_len = 150
        self.tweet_prefix = tweet_prefix
        if self.tweet_prefix:
            self.max_seq_len *= 2
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=True)
        self.prefix_length = prefix_length
        self.normalize = normalize
        self.id2neighbor = defaultdict(lambda: '')

        if dedup:
            if dupefile is not None and os.path.exists(dupefile):
                with open(dupefile, 'rb') as dupefile:
                    self.is_rep = pickle.load(dupefile).is_rep
            elif dupefile == 'both':
                with open('dupes.pkl', 'rb') as dupefile:
                    dupes1 = pickle.load(dupefile)
                with open('dupes_audio.pkl', 'rb') as dupefile:
                    dupes2 = pickle.load(dupefile)
                self.is_rep = defaultdict(lambda: True)
                for k,v in dupes1.is_rep.items():
                    self.is_rep[k] = v
                for k,v in dupes2.is_rep.items():
                    self.is_rep[k] = v
            else:
                sys.exit('Could not find duplicate file')

        subfolders = [f'audiostock-part-{i}' for i in range(1,9)]
        self.label_path = os.path.join(dataset_path, 'audiostock-full-label')
        self.whole_track = whole_track
        self.file_list = file_list

        # select out the elements for this split
        if self.file_list == []:
            temp_file_list = []
            for subfolder in subfolders:
                temp_file_list += [os.path.join(dataset_path, subfolder, f) for f in os.listdir(os.path.join(dataset_path, subfolder)) if not dedup or self.is_rep[os.path.basename(f).split('.')[0]]]
            if split is not None:
                split = set(np.loadtxt(split, dtype = str))
                self.file_list = [f for f in temp_file_list if os.path.basename(f).split('.')[0] in split]
            else:
                self.file_list = temp_file_list

        self.train = train
        self.total_len = int(len(self.file_list) * factor)
        if verbose:
            print(f'Dataset Loaded | File Num.: {len(self.file_list)} | Batches per epoch: {self.total_len}')

    def precompute_rand(self, candidate_set=None):
        self.id2neighbor = defaultdict(lambda: '')
        # if train
        if candidate_set is None:
            my_ids = []
            candidate_caps = []
            temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
            for batch in temp_loader:
                my_ids += batch['id']
                candidate_caps += batch['short_text']
            for idx in my_ids:
                self.id2neighbor[idx] = random.choice(candidate_caps)
        # if test
        else:
            temp_loader = DataLoader(candidate_set, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
            candidate_caps = []
            for batch in temp_loader:
                candidate_caps += batch['short_text']
            temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
            my_ids = []
            for batch in temp_loader:
                my_ids += batch['id']
            for idx in my_ids:
                self.id2neighbor[idx] = random.choice(candidate_caps)

    def precompute_gold(self):
        self.id2neighbor = defaultdict(lambda: '')
        temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
        for batch in temp_loader:
            for idx,short_text in zip(batch['id'], batch['short_text']):
                self.id2neighbor[idx] = short_text

    def precompute_blank(self):
        self.id2neighbor = defaultdict(lambda: '\n')

    def precompute_neighbors(self, model, candidate_set=None):
        print('Precomputing neighbors')
        self.id2neighbor = defaultdict(lambda: '')
        # if train and model given
        if candidate_set is None:
            # compute waveform embeddings for each song
            cand_features = None
            cand_ids = []
            cand_caps = []
            temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
            progress = tqdm(total=len(temp_loader), dynamic_ncols=True)
            for batch in temp_loader:
                with torch.no_grad():
                    batch_features = model.embed_waveform(batch['waveform'].cuda())
                if cand_features is not None:
                    cand_features = torch.cat([cand_features, batch_features])
                else:
                    cand_features = batch_features
                cand_ids += batch['id']
                cand_caps += batch['short_text']
                progress.update()
            progress.close()
            my_features = cand_features
            my_ids = cand_ids
        # if test and model given
        else:
            # check if we already precomputed the embeddings
            pickle_filename = 'nn_features.pkl'
            if os.path.isfile(pickle_filename):
                with open(pickle_filename, 'rb') as f:
                    (cand_features, cand_ids, cand_caps) = pickle.load(f)
            else:
                # build the features from the provided set instead of self
                cand_features = None
                cand_ids = []
                cand_caps = []
                temp_loader = DataLoader(candidate_set, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
                progress = tqdm(total=len(temp_loader), dynamic_ncols=True)
                for batch in temp_loader:
                    with torch.no_grad():
                        batch_features = model.embed_waveform(batch['waveform'].cuda())
                    if cand_features is not None:
                        cand_features = torch.cat([cand_features, batch_features])
                    else:
                        cand_features = batch_features
                    cand_ids += batch['id']
                    #cand_caps += [' '.join(x.split()[:10]) for x in batch['short_text']]
                    cand_caps += batch['short_text']
                    progress.update()
                progress.close()
                # dump to pickle so we don't have to redo this each time
                with open(pickle_filename, 'wb') as f:
                    pickle.dump((cand_features, cand_ids, cand_caps), f)
            # load up my own ids and features
            my_features = None
            my_ids = []
            temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
            progress = tqdm(total=len(temp_loader), dynamic_ncols=True)
            for batch in temp_loader:
                with torch.no_grad():
                    batch_features = model.embed_waveform(batch['waveform'].cuda())
                if my_features is not None:
                    my_features = torch.cat([my_features, batch_features])
                else:
                    my_features = batch_features
                my_ids += batch['id']
                progress.update()
            progress.close()
        is_self_sim = my_ids == cand_ids
        for idx,audio_id in tqdm(enumerate(my_ids), total=len(my_ids), dynamic_ncols=True):
            features = my_features[idx]
            similarities = features @ cand_features.T
            # remove identical matches
            if is_self_sim:
                similarities[idx] = float('-inf')
            best_idx = torch.argmax(similarities)
            most_similar_caption = cand_caps[best_idx]
            self.id2neighbor[my_ids[idx]] = most_similar_caption

    def pad_tokens(self, tokens, tokens_tweet):
        tweet_text_len = 0
        if self.tweet_prefix:
            tweet_text_len = tokens_tweet[:self.max_seq_len // 2].shape[0]
            tokens = torch.cat((tokens_tweet[:tweet_text_len], tokens))
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
        mask = tokens.ge(0)  # mask is zero where we are out of sequence
        tokens[~mask] = 0
        mask = mask.float()
        mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0)  # adding prefix mask
        return tokens, mask, tweet_text_len

    def read_wav(self, filename):
        stem = PurePosixPath(filename).stem
        picklefile = f'wt-{self.whole_track}-t-{self.train}-{stem}.pt'
        picklepath = f'/trunk/datasets/nsrivats/audiostock_proc/{picklefile}'
        if os.path.exists(picklepath):
            y = torch.load(picklepath)
        else:
            # chunk
            try:
                num_frames = torchaudio.info(filename).num_frames
            except:
                return None
            # make sure it wasn't empty, if so die
            if num_frames == 0:
                return None
            sta = 0
            if not self.whole_track:
                if self.train:
                    sta = random.randint(0, num_frames - 441001)
                else:
                    sta = (num_frames - 441001) // 2
                num_frames = 441000

            y, sr = torchaudio.load(filename, frame_offset=sta, num_frames=num_frames)
            # resample
            y = torchaudio.functional.resample(y, sr, 48000)
            y = y[:, :441000]
            # mono
            y = y.mean(dim=0)
            # normalize
            y = int16_to_float32(float32_to_int16(y))
            # save
            torch.save(y, picklepath)
        return y

    def __getitem__(self, index):
        idx = index % len(self.file_list)
        data_dict = {}
        f = self.file_list[idx]
        lf = os.path.join(self.label_path, os.path.basename(f).split('.')[0] + '.json')
        data_dict['waveform'] = self.read_wav(f)
        if os.path.isfile(lf):
            with open(lf,'r') as label_file:
                label_data = json.load(label_file)
            data_dict['id'] = label_data['id']
            data_dict['short_text'] = label_data['short_text']
            if self.normalize:
                data_dict['short_text'] = ' '.join(muscaps_tokenize(data_dict['short_text']))
            if 'long_text' in label_data and label_data['long_text'] is not None:
                data_dict['long_text'] = label_data['long_text']
            else:
                data_dict['long_text'] = ''
            '''
            data_dict['tag'] = label_data['tag']
            data_dict['impression'] = label_data['impression']
            data_dict['purpose'] = label_data['purpose']
            '''
        else:
            data_dict['id'] = os.path.basename(f).split('.')[0]
            data_dict['short_text'] = ''
            data_dict['long_text'] = ''

        # tokenize the caption
        caption_proc = preproc(data_dict['short_text'], self.tokenizer)
        tokens = torch.tensor(caption_proc, dtype=torch.int64)
        tweet_text = self.id2neighbor[data_dict['id']] if self.tweet_prefix else ''
        tweet_proc = preproc(tweet_text, self.tokenizer, stop=False)
        tokens_tweet = torch.tensor(tweet_proc, dtype=torch.int64)
        tokens, mask, tweet_text_len = self.pad_tokens(tokens, tokens_tweet)
        data_dict['tokens'] = tokens
        data_dict['mask'] = mask
        data_dict['tweet_text_len'] = tweet_text_len
        data_dict['tweet_text'] = tweet_text

        if (data_dict['id'] is None or
            data_dict['short_text'] is None or
            data_dict['long_text'] is None or
            data_dict['tokens'] is None or
            data_dict['mask'] is None or
            data_dict['tweet_text_len'] is None or
            data_dict['tweet_text'] is None or
            data_dict['waveform'] is None
        ):
            return None
        else:
            return data_dict

    def __len__(self):
        return self.total_len

class MusicCapsDataset(AudiostockDataset):
    def __init__(self, dataset_path, args, train = True, split = None, factor = 1.0, whole_track = False, verbose=True, dedup=True):
        super(AudiostockDataset, self).__init__()
        # set up parameters
        self.max_seq_len = 150
        self.tweet_prefix = args.tweet_prefix
        if self.tweet_prefix:
            self.max_seq_len *= 2
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=True)
        self.prefix_length = args.prefix_length
        self.normalize = args.normalize
        self.whole_track = whole_track

        self.label_path = os.path.join(dataset_path, 'audio')
        self.file_list = []
        self.label_data = []
        label_reader = pd.read_csv(f'{dataset_path}/musiccaps-resplit.csv')
        for idx,row in label_reader.iterrows():
            if (row['is_audioset_eval'] == 1 and split == 'musiccaps_eval') \
                    or (row['is_audioset_eval'] == 0 and split == 'musiccaps_train') \
                    or (row['is_audioset_eval'] == 2 and split == 'musiccaps_dev'):
                data_dict = {}
                data_dict['id'] = row['ytid']
                self.file_list.append(f"{dataset_path}/audio/{data_dict['id']}.wav")
                data_dict['short_text'] = row['caption']
                if self.normalize:
                    data_dict['short_text'] = ' '.join(muscaps_tokenize(data_dict['short_text']))
                data_dict['long_text'] = ''
                data_dict['tag'] = row['aspect_list']
                self.label_data.append(data_dict)

        self.train = train
        self.total_len = int(len(self.file_list) * factor)
        if verbose:
            print(f'Dataset Loaded | File Num.: {len(self.file_list)} | Batches per epoch: {self.total_len}')

    def __getitem__(self, index):
        idx = index % len(self.file_list)
        data_dict = {}
        f = self.file_list[idx]
        data_dict['waveform'] = self.read_wav(f)
        for k,v in self.label_data[idx].items():
            data_dict[k] = v

        # tokenize the caption
        caption_proc = preproc(data_dict['short_text'], self.tokenizer)
        tokens = torch.tensor(caption_proc, dtype=torch.int64)
        tweet_text = self.id2neighbor[data_dict['id']] if self.tweet_prefix else ''
        tweet_proc = preproc(tweet_text, self.tokenizer, stop=False)
        tokens_tweet = torch.tensor(tweet_proc, dtype=torch.int64)
        tokens, mask, tweet_text_len = self.pad_tokens(tokens, tokens_tweet)
        data_dict['tokens'] = tokens
        data_dict['mask'] = mask
        data_dict['tweet_text_len'] = tweet_text_len
        data_dict['tweet_text'] = tweet_text

        if (data_dict['id'] is None or
            data_dict['short_text'] is None or
            data_dict['long_text'] is None or
            data_dict['tokens'] is None or
            data_dict['mask'] is None or
            data_dict['tweet_text_len'] is None or
            data_dict['tweet_text'] is None or
            data_dict['waveform'] is None
        ):
            return None
        else:
            return data_dict
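A sketch of how AudiostockDataset is typically consumed. The dataset path is hypothetical, and the example assumes the audiostock-part-1..8 / audiostock-full-label layout, a local dupes.pkl, and that the module's own dependencies (lib.py, utils.py) resolve:

# sketch only: paths are hypothetical and the audiostock layout must exist locally
from torch.utils.data import DataLoader
from data_module import AudiostockDataset, my_collate

train_set = AudiostockDataset(
    dataset_path='/path/to/audiostock-full',   # hypothetical location
    train=True,
    split='audiostock-train-240k.txt',
    dupefile='dupes.pkl',
)
loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=8,
                    drop_last=True, collate_fn=my_collate)  # my_collate drops items whose audio failed to load
batch = next(iter(loader))
print(batch['waveform'].shape, batch['tokens'].shape, batch['mask'].shape)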
dupes.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e83b71d63cd11dc8840b44bcea625d1c618c8b421e4c6ec6c65580af5109c7bd
size 1807022
infer.py
ADDED
@@ -0,0 +1,55 @@
from audiocaptioner import AudioCaptioner
from data_module import AudiostockDataset
from utils import *

def infer(input_filename):
    device = get_device(0)
    # connect to GCS
    gcs = CheckpointManager()
    # create and/or load model
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=False)
    prefix_dim = 512
    prefix_length = 10
    prefix_length_clip = 10
    num_layers = 8
    checkpoint = 'checkpoints/ZRIUE-BEST.pt'
    model = AudioCaptioner(prefix_length, clip_length=prefix_length_clip, prefix_size=prefix_dim, num_layers=num_layers).to(device)
    model.load_state_dict(gcs.get_checkpoint(checkpoint))
    print(f'Loaded from {checkpoint}')
    model.eval()
    # read in the wav file and precompute neighbors
    #dataset_path = '/graft1/datasets/kechen/audiostock-full'
    dataset_path = ''
    train_dataset = AudiostockDataset(
        dataset_path=dataset_path,
        train=False,
        split='audiostock-train-240k.txt',
        factor=1.0,
        verbose=False,
        file_list=open('audiostock-train-240k.txt', 'r').read().split()
    )
    print('Reading in file', input_filename)
    dataset = AudiostockDataset(
        dataset_path=dataset_path,
        train=False,
        split=None,
        factor=1.0,
        verbose=False,
        file_list=[input_filename]  # manually override file list
    )
    dataset.precompute_neighbors(model, candidate_set=train_dataset)
    waveform = dataset.read_wav(input_filename).unsqueeze(0).to(device, dtype=torch.float32)
    # predict
    with torch.no_grad():
        prefix_embed = model.create_prefix(waveform, 1)
        tweet_tokens = torch.tensor(preproc(dataset.id2neighbor[os.path.basename(input_filename).split('.')[0]], tokenizer, stop=False), dtype=torch.int64).to(device)[:150]
        tweet_embed = model.gpt.transformer.wte(tweet_tokens)
        prefix_embed = torch.cat([prefix_embed, tweet_embed.unsqueeze(0)], dim=1)
        candidates = generate_beam(model, tokenizer, embed=prefix_embed, beam_size=5)
        generated_text = candidates[0]
        generated_text = postproc(generated_text)
    print('=======================================')
    print(generated_text)
    # return the caption so callers (e.g. the Gradio interface in app.py) receive it
    return generated_text

if __name__ == '__main__':
    infer('../MusicCaptioning/sample_inputs/sisters.mp3')
lib.py
ADDED
@@ -0,0 +1,19 @@
from collections import defaultdict
import json
import numpy as np
import os
import pandas as pd
import dill as pickle
pickle._dill._reverse_typemap['ClassType'] = type
import random
import string
import sys
import torch
from torch import nn
import torch.nn.functional as nnf
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from typing import Tuple, List, Union, Optional

from clipcap import *
utils.py
ADDED
@@ -0,0 +1,45 @@
from lib import *
from twokenize import tokenizeRawTweetText
import re

def muscaps_tokenize(raw):
    raw = raw.lower()
    for punc in string.punctuation:
        raw = raw.replace(punc, ' ')
    tokens = raw.split()
    return tokens

def get_device(device_id: int) -> torch.device:
    if not torch.cuda.is_available():
        return torch.device('cpu')
    device_id = min(torch.cuda.device_count() - 1, device_id)
    return torch.device(f'cuda:{device_id}')

def preproc(caption, tokenizer, stop=True):
    caption = caption.replace('.', '<STOP>')
    caption_proc = tokenizer.encode(caption)
    if stop:
        caption_proc += tokenizer.encode('.')
    return caption_proc

def postproc(caption):
    caption = caption.replace('<STOP>', '.')
    if caption[-1] == '.':
        caption = caption[:-1]
    return caption

class CheckpointManager:
    def __init__(self):
        self.checkpoint_dir = '/home/nsrivats/Repositories/MusicCaptioning/checkpoints'

    def get_checkpoint(self, checkpoint):
        with open(checkpoint, 'rb') as infile:
            return torch.load(infile)

    def save_checkpoint(self, state_dict, checkpoint):
        filename = f'{self.checkpoint_dir}/{checkpoint}'
        with open(filename, 'wb') as outfile:
            torch.save(state_dict, outfile)

    def save_logs(self, logdir):
        pass
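A quick illustration of the preproc/postproc round trip (a sketch; the sample caption is made up, and it assumes utils' own imports such as twokenize resolve):

# '.' is swapped for '<STOP>' before GPT-2 encoding and restored afterwards
from transformers import GPT2Tokenizer
from utils import preproc, postproc

tok = GPT2Tokenizer.from_pretrained('gpt2')
ids = preproc('Upbeat synth pop. Bright and catchy.', tok)  # ends with an encoded '.' as the stop token
decoded = tok.decode(ids)
print(postproc(decoded))  # '<STOP>' mapped back to '.', trailing stop trimmed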