sino committed on
Commit
a62a2b1
1 Parent(s): 8d2d274

Upload modeling_maelm.py

Files changed (1)
  1. modeling_maelm.py +592 -0
modeling_maelm.py ADDED
@@ -0,0 +1,592 @@
import json
import os
import pdb
from mmcv.cnn.bricks import padding
import torch
from torch import nn, einsum
from typing import Optional, Dict, Tuple
from src.mae_vit import MAEViT
from src.htsat import HTSAT_Swin_Transformer, create_htsat_model
from src.LMdecoder import LMDecoder, LMDecoder_qlora
from src.vision_transformer import VisionTransformer
from einops import rearrange, repeat
from einops_exts import rearrange_many
import inspect

from transformers.modeling_utils import PreTrainedModel
from .configuration_maelm import MAELMConfig


class ArgsHandler:
    """Read and patch the arguments of a module's forward() call by name,
    whether they were passed positionally or as keywords."""

    def __init__(self, module, funcname, fargs, fkargs):
        self.fargs = list(fargs)
        self.fkargs = fkargs
        func = getattr(module, funcname)
        fal_repr = f"{funcname}_argnames_list"
        # cache the parameter-name list on the module so the signature is
        # inspected only once per module
        if (argns_list := getattr(module, fal_repr, None)) is None:
            self.func_sig = inspect.signature(func)
            self.argnames_list = list(self.func_sig.parameters.keys())
            setattr(module, fal_repr, self.argnames_list)
        else:
            self.argnames_list = argns_list

    def get_arg(self, arg_name):
        if arg_name in self.fkargs:
            arg = self.fkargs[arg_name]
        else:
            arg = self.fargs[self.argnames_list.index(arg_name)]
        return arg

    def set_arg(self, arg_name, arg_value):
        if arg_name in self.fkargs:
            self.fkargs[arg_name] = arg_value
        else:
            self.fargs[self.argnames_list.index(arg_name)] = arg_value

    def return_all_args(self):
        return tuple(self.fargs), self.fkargs

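# Illustrative sketch (not part of the original upload): ArgsHandler lets a
# forward pre-hook read or patch an argument of a module's forward() by name,
# no matter whether the caller passed it positionally or as a keyword.
# Assuming a module whose forward signature is forward(hidden_states, attention_mask=None):
#
#     def prehook(module, args, kwargs):
#         ahl = ArgsHandler(module, 'forward', args, kwargs)
#         hs = ahl.get_arg('hidden_states')
#         ahl.set_arg('hidden_states', hs * 2.0)   # hypothetical edit
#         return ahl.return_all_args()
#
#     module.register_forward_pre_hook(prehook, with_kwargs=True)
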
class SquaredReLU(nn.Module):
    """Squared ReLU activation function."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.pow(torch.relu(x), 2)


def FeedForward(dim, out_dim, mult=4, act='gelu'):
    """
    lucidrains' implementation, slightly modified with the `act` parameter.
    """

    acts = dict(
        gelu=nn.GELU,
        sqrelu=SquaredReLU,
        relu=nn.ReLU
    )

    assert act in acts, f"act can only be one of {acts.keys()}"

    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        acts[act](),
        nn.Linear(inner_dim, out_dim, bias=False)
    )

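# Quick shape sketch (added for illustration, not in the original upload):
# FeedForward is a pre-norm MLP, LayerNorm -> Linear(dim, dim*mult) -> act -> Linear(inner, out_dim).
#
#     ff = FeedForward(dim=512, out_dim=512, mult=4, act='gelu')
#     y = ff(torch.randn(2, 32, 512))   # -> (2, 32, 512)
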

class PerceiverAttentionLayer(nn.Module):
    def __init__(
            self,
            *,
            feat_dim,
            latent_dim,
            dim_head=64,
            heads=8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head

        inner_dim = dim_head * heads

        # trainable components of PerceiverAttentionLayer
        self.norm_media = nn.LayerNorm(feat_dim)
        self.norm_latents = nn.LayerNorm(latent_dim)

        self.to_q = nn.Linear(latent_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(feat_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(feat_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, latent_dim, bias=False)

    def forward(self, features, latents):
        """
        Latent vectors cross-attend to the input features.
        :param features: Tensor (n_batch, n_features, feat_dim)
            input (visual/audio) features
        :param latents: Tensor (n_batch, n_latents, latent_dim)
            learnt latent vectors from which the queries are computed
            (the same vectors, replicated along the batch dimension).
        :return: Tensor (n_batch, n_latents, latent_dim)
        """
        assert features.ndim == 3
        assert latents.ndim == 3
        assert features.shape[0] == latents.shape[0]
        # assert features.shape[2] == latents.shape[2]

        n_heads = self.heads
        n_batch, n_features, dim = features.shape
        n_queries = latents.shape[1]

        # layer normalization, as usual
        x = self.norm_media(features)
        latents = self.norm_latents(latents)

        # queries
        # compute the queries from the latents, for all attention heads simultaneously
        q = self.to_q(latents)
        q = rearrange(q, 'b q (h d) -> b h q d', h=n_heads)
        assert q.shape == torch.Size([n_batch, n_heads, n_queries, self.dim_head])

        # keys and values for all attention heads

        '''
        kv_input = torch.cat((x, latents), dim=-2)
        n_features_latents = n_features + n_queries
        '''

        kv_input = x
        n_features_latents = n_features

        # keys, values
        k = self.to_k(kv_input)
        v = self.to_v(kv_input)
        # batch, features, (heads, dim)

        # split so we have an extra dimension for the heads
        # q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h=h)
        k, v = rearrange_many((k, v), 'b f (h d) -> b h f d', h=n_heads)
        assert v.shape == torch.Size([n_batch, n_heads, n_features_latents, self.dim_head])

        # scale the queries
        q = q * self.scale

        # attention scores
        # sim = einsum('... i d, ... j d -> ... i j', q, k)
        sim = einsum('b h q d, b h f d -> b h q f', q, k)

        # subtract the row-wise max for numerical stability;
        # it does not affect the result of the softmax operation
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        alphas = sim.softmax(dim=-1)

        # out = einsum('... i j, ... j d -> ... i d', alphas, v)
        out = einsum('b h q f, b h f v -> b h q v', alphas, v)

        # out = rearrange(out, 'b h t n d -> b t n (h d)', h=h)
        out = rearrange(out, 'b h q v -> b q (h v)')
        return self.to_out(out)

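# Expected shapes (illustrative sketch, not in the original upload): the layer
# cross-attends a fixed set of latents to a variable-length feature sequence and
# returns updated latents of the same length; the dims below are hypothetical.
#
#     layer = PerceiverAttentionLayer(feat_dim=1024, latent_dim=4096, dim_head=64, heads=8)
#     feats = torch.randn(2, 768, 1024)   # (batch, n_features, feat_dim)
#     lats = torch.randn(2, 32, 4096)     # (batch, n_latents, latent_dim)
#     out = layer(feats, lats)            # -> (2, 32, 4096)
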

class MAEForCausalLM(PreTrainedModel):
    """Audio encoder (MAE-ViT or HTSAT) bridged into a causal LM decoder
    through a Perceiver-style resampler and bottleneck adapters.

    Args:
        backbone (dict): Config dict for the audio encoder. Defaults to None.
        neck (dict): Config dict for the language-model decoder. Defaults to None.
        head (dict): Config dict for loss functions. Defaults to None.
        init_cfg (dict, optional): Config dict for weight initialization.
            Defaults to None.
    """

    config_class = MAELMConfig

    def __init__(self, config: MAELMConfig) -> None:
        super().__init__(config)
        backbone = config.backbone
        assert backbone is not None
        bk_name = backbone.pop('name')
        self.bk_name = bk_name
        if bk_name == 'MAEViT':
            ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
            self.backbone = MAEViT(**backbone)
            if ckpt_path is not None:
                ckpt = torch.load(ckpt_path, map_location='cpu')
                self.backbone.load_state_dict(ckpt['state_dict'])

        elif bk_name == 'HTSAT':
            ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
            self.backbone = create_htsat_model(backbone)
            if ckpt_path is not None:
                ckpt = torch.load(ckpt_path, map_location='cpu')
                self.backbone.load_state_dict(ckpt['state_dict'])
        elif bk_name == 'qformer':
            raise NotImplementedError
        else:
            raise NotImplementedError

        # neck["num_patches"] = self.backbone.num_patches
        # neck["patch_resolution"] = self.backbone.patch_resolution
        neck = config.neck
        assert neck is not None
        nk_name = neck.pop('name')
        if nk_name == 'LMDecoder':
            self.neck = LMDecoder(**neck)
        elif nk_name == 'LMDecoder_qlora':
            self.neck = LMDecoder_qlora(**neck)
        else:
            raise NotImplementedError
        self.config = self.neck.LMconfig  # TODO

        '''
        self.ae_proj = nn.Linear(
            768, self.config.hidden_size
        )
        '''

        ## TODO
        # self.neck.lm.apply(lambda m: setattr(m, 'gradient_checkpointing', True))
        self.neck.lm.model.gradient_checkpointing = False

        self.register_buffer('ones', torch.ones((1, 4096), dtype=torch.long), persistent=False)
        self.graft_adapter()
        self.init_weights()
        # float32 --> bfloat16
        for p in self.parameters():
            p.data = p.data.to(torch.bfloat16)
        if config.resume_from_checkpoint is not None:
            drain_loader = True
            # NOTE: relies on an `accelerate.Accelerator` instance named
            # `accelerator` being available in the calling scope.
            accelerator.load_state(config.resume_from_checkpoint, load_module_strict=False)
            # start_epoch, start_step, all_step = [int(_.split('_')[1]) for _ in args.resume_from_checkpoint.split('/')[-2].split('-')]
        elif config.resume_from_pth is not None:
            print(f'###########loading##########{config.resume_from_pth}###########loading##########')
            ckpt = torch.load(config.resume_from_pth, map_location='cpu')
            # strip the "module." prefix left by DistributedDataParallel
            ckpt_copy = {k[7:]: v for k, v in ckpt.items()}
            self.load_state_dict(ckpt_copy, strict=False)
            print(f'###########loaded##########{config.resume_from_pth}###########loaded##########')

        if False:
            self.patch_llm()
        self.first_run = True

    def graft_adapter(self):
        adapter_latent_len = 32
        self.adapter_latent_len = adapter_latent_len
        self.adapter_latent = nn.Parameter(torch.rand((1, adapter_latent_len, self.config.hidden_size),
                                                      dtype=torch.float))
        resampler_latent_len = 32
        self.resampler_latent_len = resampler_latent_len
        self.resampler_latent = nn.Parameter(torch.rand((1, resampler_latent_len, self.config.hidden_size),
                                                        dtype=torch.float))
        ## TODO
        # self.adapter.pre_bn = torch.nn.BatchNorm1d(4096, affine=True)

        self.adapter = nn.ModuleList([])

        ff_mult = 4
        heads = 8
        dim_head = 512
        act = 'gelu'

        lm_dim = self.config.hidden_size
        if self.bk_name == 'HTSAT':
            feat_dim = 1024
            depth = len(self.backbone.layers[2].blocks)
        else:
            feat_dim = 768
            depth = int(len(self.neck.lm.model.layers) / 2)  # 16
        for idx in range(depth):
            self.adapter.append(nn.ModuleList([
                Adapter(input_size=self.config.hidden_size),
                # PerceiverAttentionLayer(feat_dim=feat_dim, latent_dim=lm_dim, dim_head=dim_head, heads=heads),
                # FeedForward(dim=lm_dim, out_dim=lm_dim, mult=1, act=act),
                # FeedForward(dim=self.dim, out_dim=768, mult=ff_mult, act=act) if idx != depth-1 else nn.Identity()
            ]))

        self.samplers = nn.ModuleList([])  # add
        for _ in range(3):
            self.samplers.append(nn.ModuleList([
                PerceiverAttentionLayer(feat_dim=feat_dim, latent_dim=lm_dim, dim_head=64, heads=heads),
                FeedForward(dim=lm_dim, out_dim=lm_dim, mult=4),
            ]))
        self.norm = nn.LayerNorm(lm_dim)

        # self.agate_list = nn.ParameterList([])
        # for i in range(len(self.neck.lm.model.layers)):
        #     self.agate_list.append(nn.Parameter(torch.zeros(lm_dim)))

    def init_weights(self):
        try:
            super().init_weights()
        except Exception:
            pass
            # import traceback
            # traceback.print_exc()
        if getattr(self, 'adapter_latent', None) is not None:
            self.adapter_latent.data.normal_(mean=0.0, std=0.02)
        if getattr(self, 'resampler_latent', None) is not None:
            self.resampler_latent.data.normal_(mean=0.0, std=0.02)

    def forward_resampler(self, x):
        # x: (b, 768, 512) audio embedding from the backbone
        latents = repeat(self.resampler_latent, 'b n d -> (bs b) n d', bs=x.shape[0])
        for attn, ff in self.samplers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        v2t_feats = self.norm(latents)
        # v2t_atts = torch.ones(v2t_feats.shape[:2], dtype=torch.long, device=v2t_feats.device)
        return v2t_feats  # (bs, 32, dim_llm)

    def hook_adapter(self, audio_embedding, lm, v2t_feats):

        class PHooker:
            # model = self.backbone
            # mgtr = self.backbone.forward_generator(spectrogram)
            adapter = self.adapter
            y = v2t_feats
            handles_list = list()
            cnter = 0

            def layer_prehook(self, m, margs, mkargs):
                ahl = ArgsHandler(m, 'forward', margs, mkargs)

                # print(self.cnter)
                # if self.cnter >= 16:
                #     self.cnter += 1
                #     return None
                adapt = self.adapter[self.cnter][0]

                hs = ahl.get_arg("hidden_states")
                adapter_residual = hs
                neo_hs = adapt(hs, adapter_residual)

                self.cnter += 1
                ahl.set_arg("hidden_states", neo_hs)
                return ahl.return_all_args()

            def first_layer_prehook(self, m, margs, mkargs):
                ahl = ArgsHandler(m, 'forward', margs, mkargs)
                neo_lm_latents = self.y  # torch.Size([128, 32, 4096])
                hs = ahl.get_arg("hidden_states")  # torch.Size([128, 87, 4096])
                hs_msk = self.lm_ahl.get_arg("input_ids") < 0  # torch.Size([128, 87]); True at the 32 audio positions
                # __import__('pdb').set_trace()
                # the resampler latents directly replace the masked (audio) positions
                neo_hs = hs.masked_scatter(hs_msk.unsqueeze(-1), neo_lm_latents)
                ahl.set_arg("hidden_states", neo_hs)
                return ahl.return_all_args()

            def lm_prehook(self, m, margs, mkargs):
                self.lm_ahl = ArgsHandler(m, 'forward', margs, mkargs)
                return None

            def last_layer_hook(self, m, margs, mkargs):
                # __import__('pdb').set_trace()
                self.cnter = 0

        if getattr(lm, 'phooker', False):
            for _ in lm.phooker.handles_list:
                _.remove()
            del lm.phooker
            lm.phooker = None
        phooker = PHooker()
        phooker.handles_list.append(lm.register_forward_pre_hook(phooker.lm_prehook, with_kwargs=True))
        # insert the resampled audio latents at the first decoder layer
        phooker.handles_list.append(lm.model.layers[0].register_forward_pre_hook(phooker.first_layer_prehook, with_kwargs=True))

        for ii in range(1, len(lm.model.layers), 2):
            l = lm.model.layers[ii]
            handle = l.register_forward_pre_hook(phooker.layer_prehook, with_kwargs=True)
            phooker.handles_list.append(handle)
        phooker.handles_list.append(lm.model.layers[-1].register_forward_pre_hook(phooker.last_layer_hook, with_kwargs=True))
        lm.phooker = phooker
        return None

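    # Sketch of the hook layout installed by hook_adapter (illustrative comment,
    # not in the original upload). For a decoder `lm` with N layers:
    #
    #     lm (pre-hook)                 -> lm_prehook: captures input_ids via ArgsHandler
    #     lm.model.layers[0]            -> first_layer_prehook: masked_scatter of the 32
    #                                      resampler latents into the audio token slots
    #     lm.model.layers[1, 3, 5, ...] -> layer_prehook: applies self.adapter[cnter][0]
    #                                      (a bottleneck Adapter) to hidden_states
    #     lm.model.layers[-1]           -> last_layer_hook: resets the adapter counter
    #
    # Hooks are re-installed on every training/inference forward; any previously
    # registered PHooker handles are removed first.
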

    def prepare_ids(self, batch, audio_ids):
        toker = self.neck.tokenizer
        # for idx, l in enumerate(self.neck.lm.model.layers):
        #     l.agate = self.agate_list[idx].clone()  ## should clone the parameter

        with torch.no_grad():

            input_ids = batch['input_ids']
            att_msk = batch['attention_mask']
            au_crds = batch['audio_crds']
            ans_crds = batch['ans_crds']
            bsz = input_ids.shape[0]
            # __import__('pdb').set_trace()
            ## TODO
            merged_ids, merged_msk, label_ids = list(), list(), list()
            for i in range(bsz):
                # audio positions are encoded as negative ids (-1 * audio_ids - 1)
                # so that downstream code can locate them
                # cur_merged_ids = torch.cat([input_ids[i,:au_crds[i]], -1 * audio_ids[i] - 1, input_ids[i,au_crds[i]:]])
                cur_merged_ids = torch.cat([-1 * audio_ids[i] - 1, input_ids[i, au_crds[i]:]])

                # cur_au_msk = self.ones[:,:audio_ids.shape[1]][0].clone().type_as(att_msk).detach()
                cur_au_msk = torch.ones(audio_ids.shape[1], device=audio_ids.device)
                # cur_merged_msk = torch.cat([att_msk[i,:au_crds[i]], cur_au_msk, att_msk[i,au_crds[i]:]])
                cur_merged_msk = torch.cat([cur_au_msk, att_msk[i, au_crds[i]:]])
                cur_label_ids = cur_merged_ids.clone().detach()
                # mask everything before the answer span out of the LM loss
                cur_label_ids[:audio_ids.shape[1] + ans_crds[i]] = -100

                merged_ids.append(cur_merged_ids)
                merged_msk.append(cur_merged_msk)
                label_ids.append(cur_label_ids)

            merged_ids = torch.stack(merged_ids, dim=0)
            merged_msk = torch.stack(merged_msk, dim=0)
            label_ids = torch.stack(label_ids, dim=0)

            assert merged_ids.shape[0] == bsz
            assert merged_ids.shape == merged_msk.shape

            label_msk = merged_msk.clone()
            assert label_msk.shape == merged_msk.shape
            assert merged_msk[:, -1].max() == 1

            for i in range(len(ans_crds)):
                label_ids[i, :audio_ids.shape[1] + ans_crds[i]].fill_(-100)

            merged_labels = label_ids
            merged_ids[merged_ids.eq(-100)] = toker.pad_token_id

        return merged_ids, merged_msk, merged_labels

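    # Worked example of the layout built above (illustrative comment only, not in
    # the original upload). For sample i the merged sequence is
    #
    #     [ -1*audio_ids[i]-1  (32 negative audio ids) | input_ids[i, au_crds[i]:] ]
    #
    # and every label before position audio_len + ans_crds[i] is set to -100, so
    # only the answer tokens contribute to the causal-LM loss. Ids equal to -100
    # (padding) are replaced with pad_token_id, while the negative audio ids are
    # kept: the first-layer hook finds them via input_ids < 0 and overwrites their
    # hidden states with the resampler latents.
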
    def forward(self, batch, **kwargs):
        """Forward computation during training.

        Args:
            batch (dict): A batch with 'input_ids', 'attention_mask',
                'spectrogram', 'audio_crds' and 'ans_crds'.
            kwargs: Any keyword arguments to be used in the forward.
        Returns:
            The output of the decoder neck, including the LM loss since
            labels are provided.
        """
        bsz = len(batch['input_ids'])
        device = batch['input_ids'].device
        float_type = next(self.parameters()).dtype
        spectrogram = batch['spectrogram'].type(float_type)
        audio_embedding = self.backbone(spectrogram).detach()  # (b, 768, 512)
        resampler_feats = self.forward_resampler(audio_embedding)
        self.hook_adapter(audio_embedding, self.neck.lm, resampler_feats)  # install hooks

        # self.hook_resapmler(resampler_feats, self.neck.lm)

        audio_ids = torch.arange(self.adapter_latent.shape[1]).unsqueeze(0).repeat((bsz, 1)).long().to(device)
        assert audio_ids.max() < 100
        merged_ids, merged_msk, merged_labels = self.prepare_ids(batch, audio_ids)

        try:
            assert merged_ids.shape == merged_labels.shape
            outs = self.neck(input_ids=merged_ids.contiguous().long(),
                             flatten_embs=self.adapter_latent.flatten(0, 1),  # (32, 4096)
                             # flatten_embs=resampler_feats.flatten(0, 1),  # (b, 32, 4096)
                             attention_mask=merged_msk.contiguous().long(),
                             labels=merged_labels.contiguous().long(), use_cache=False)
        except Exception as e:
            import traceback
            traceback.print_exc()
            __import__('remote_pdb').set_trace()
        # outs.hidden_logits = self.hidden_logits

        ## TODO
        if eval(os.environ.get("doing_eval", 'False')):
            outs.merged_ids = merged_ids.cpu()
            outs.merged_labels = merged_labels.cpu()

        return outs


    def forward_test(self, batch, **kwargs):
        """Forward computation during inference (generation).

        Args:
            batch (dict): A batch with 'input_ids', 'attention_mask',
                'spectrogram', 'audio_crds' and 'ans_crds'.
            kwargs: Any keyword arguments to be used in the forward.
        Returns:
            The output of ``self.neck.generate``.
        """

        bsz = len(batch['input_ids'])
        device = batch['input_ids'].device
        float_type = next(self.parameters()).dtype
        spectrogram = batch['spectrogram'].type(float_type)
        audio_embedding = self.backbone(spectrogram).detach()  # (b, 768, 512)
        resampler_feats = self.forward_resampler(audio_embedding)
        self.hook_adapter(audio_embedding, self.neck.lm, resampler_feats)  # install hooks
        # self.extract_features(batch, self.neck.lm)
        audio_ids = torch.arange(self.adapter_latent.shape[1]).unsqueeze(0).repeat((bsz, 1)).long().to(device)
        assert audio_ids.max() < 100

        merged_ids, merged_msk, merged_labels = self.prepare_ids(batch, audio_ids)
        au_crds = batch['audio_crds']
        ans_crds = batch['ans_crds']

        aid_len = audio_ids.shape[-1]

        toker = self.neck.tokenizer
        with torch.no_grad():

            ## TODO
            pad_token = toker.encode(self.neck.tokenizer.eos_token)[0]
            # left-pad every sample so that the prompts end at the same position
            padded_merged_ids = self.ones[:, :aid_len + max(ans_crds)].repeat(bsz, 1).clone().detach() * pad_token
            for i in range(bsz):
                # for i in range(1):
                assert au_crds[i] <= ans_crds[i]
                cur_ids = merged_ids[i][:aid_len + ans_crds[i]]
                padded_merged_ids[i][max(ans_crds) - ans_crds[i]:] = cur_ids
            # __import__('pdb').set_trace()
            outs = self.neck.generate(padded_merged_ids, self.adapter_latent.flatten(0, 1))
        # outs.hidden_logits = self.hidden_logits

        return outs

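# Rough usage sketch (assumptions flagged; none of this is in the original file).
# MAELMConfig is expected to provide `backbone` and `neck` dicts plus the two
# resume_* fields read in __init__; the exact keys below are hypothetical.
#
#     config = MAELMConfig(
#         backbone=dict(name='HTSAT', ckpt='path/to/htsat.ckpt'),   # hypothetical keys
#         neck=dict(name='LMDecoder'),                              # hypothetical keys
#     )
#     model = MAEForCausalLM(config)
#     outs = model(batch)   # batch: input_ids, attention_mask, spectrogram,
#                           #        audio_crds, ans_crds
#     loss = outs.loss      # causal-LM loss, assuming the neck returns a HF-style output
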

import torch
from torch import nn

from transformers.activations import ACT2FN


class Adapter(nn.Module):
    """
    Implementation of a sequential bottleneck adapter block.
    """
    def __init__(
        self,
        input_size,
        down_sample=None,
    ):
        super().__init__()

        self.input_size = input_size

        # if a downsample size is not passed, we just halve the size of the original input
        self.down_sample = down_sample
        if down_sample is None:
            self.down_sample = self.input_size // 2

        self.adapter_norm_before = nn.LayerNorm(self.input_size)
        self.adapter_down = nn.Linear(self.input_size, self.down_sample)
        self.non_linearity = ACT2FN["silu"]

        # Up projection to input size
        self.adapter_up = nn.Linear(self.down_sample, self.input_size)

        # Additional scaling factor (from He et al. (2021))
        self.scaling = nn.Parameter(torch.ones(1))

        self.adapter_down.apply(self._init_weights)
        self.adapter_up.apply(self._init_weights)

    def forward(self, x, residual_input):

        down = self.non_linearity(self.adapter_down(self.adapter_norm_before(x)))

        up = self.adapter_up(down)
        up = up * self.scaling
        output = up

        output = output + residual_input

        return output

    @staticmethod
    def _init_weights(module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # std defaults to 0.02, this might need to be changed
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
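

if __name__ == "__main__":
    # Minimal smoke test for the bottleneck Adapter (added for illustration only;
    # it assumes the optional backbone/decoder dependencies imported at the top of
    # this file are installed, since they load with the module).
    adapter = Adapter(input_size=16)
    x = torch.randn(2, 5, 16)
    out = adapter(x, residual_input=x)
    assert out.shape == x.shape
    print("Adapter output shape:", tuple(out.shape))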