euiyulsong
commited on
Commit
•
eac156b
1
Parent(s):
7368bdd
Create fid.py
Browse files
fid.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import types
|
8 |
+
import torch
|
9 |
+
import transformers
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import nn
|
12 |
+
from torch.nn import CrossEntropyLoss
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
class FiDT5(transformers.T5ForConditionalGeneration):
|
16 |
+
def __init__(self, config):
|
17 |
+
super().__init__(config)
|
18 |
+
self.wrap_encoder()
|
19 |
+
|
20 |
+
def forward_(self, **kwargs):
|
21 |
+
if 'input_ids' in kwargs:
|
22 |
+
kwargs['input_ids'] = kwargs['input_ids'].view(kwargs['input_ids'].size(0), -1)
|
23 |
+
if 'attention_mask' in kwargs:
|
24 |
+
kwargs['attention_mask'] = kwargs['attention_mask'].view(kwargs['attention_mask'].size(0), -1)
|
25 |
+
|
26 |
+
return super(FiDT5, self).forward(
|
27 |
+
**kwargs
|
28 |
+
)
|
29 |
+
|
30 |
+
# We need to resize as B x (N * L) instead of (B * N) x L here
|
31 |
+
# because the T5 forward method uses the input tensors to infer
|
32 |
+
# dimensions used in the decoder.
|
33 |
+
# EncoderWrapper resizes the inputs as (B * N) x L.
|
34 |
+
def forward(self, input_ids=None, attention_mask=None, **kwargs):
|
35 |
+
if input_ids != None:
|
36 |
+
# inputs might have already be resized in the generate method
|
37 |
+
if input_ids.dim() == 3:
|
38 |
+
self.encoder.n_passages = input_ids.size(1)
|
39 |
+
input_ids = input_ids.view(input_ids.size(0), -1)
|
40 |
+
if attention_mask != None:
|
41 |
+
attention_mask = attention_mask.view(attention_mask.size(0), -1)
|
42 |
+
return super().forward(
|
43 |
+
input_ids=input_ids,
|
44 |
+
attention_mask=attention_mask,
|
45 |
+
**kwargs
|
46 |
+
)
|
47 |
+
|
48 |
+
# We need to resize the inputs here, as the generate method expect 2D tensors
|
49 |
+
def generate(self, input_ids, attention_mask, max_length):
|
50 |
+
self.encoder.n_passages = input_ids.size(1)
|
51 |
+
return super().generate(
|
52 |
+
input_ids=input_ids.view(input_ids.size(0), -1),
|
53 |
+
attention_mask=attention_mask.view(attention_mask.size(0), -1),
|
54 |
+
max_length=max_length
|
55 |
+
)
|
56 |
+
|
57 |
+
def wrap_encoder(self, use_checkpoint=False):
|
58 |
+
"""
|
59 |
+
Wrap T5 encoder to obtain a Fusion-in-Decoder model.
|
60 |
+
"""
|
61 |
+
self.encoder = EncoderWrapper(self.encoder, use_checkpoint=use_checkpoint)
|
62 |
+
|
63 |
+
def unwrap_encoder(self):
|
64 |
+
"""
|
65 |
+
Unwrap Fusion-in-Decoder encoder, useful to load T5 weights.
|
66 |
+
"""
|
67 |
+
self.encoder = self.encoder.encoder
|
68 |
+
block = []
|
69 |
+
for mod in self.encoder.block:
|
70 |
+
block.append(mod.module)
|
71 |
+
block = nn.ModuleList(block)
|
72 |
+
self.encoder.block = block
|
73 |
+
|
74 |
+
def load_t5(self, state_dict):
|
75 |
+
self.unwrap_encoder()
|
76 |
+
self.load_state_dict(state_dict)
|
77 |
+
self.wrap_encoder()
|
78 |
+
|
79 |
+
def set_checkpoint(self, use_checkpoint):
|
80 |
+
"""
|
81 |
+
Enable or disable checkpointing in the encoder.
|
82 |
+
See https://pytorch.org/docs/stable/checkpoint.html
|
83 |
+
"""
|
84 |
+
for mod in self.encoder.encoder.block:
|
85 |
+
mod.use_checkpoint = use_checkpoint
|
86 |
+
|
87 |
+
def reset_score_storage(self):
|
88 |
+
"""
|
89 |
+
Reset score storage, only used when cross-attention scores are saved
|
90 |
+
to train a retriever.
|
91 |
+
"""
|
92 |
+
for mod in self.decoder.block:
|
93 |
+
mod.layer[1].EncDecAttention.score_storage = None
|
94 |
+
|
95 |
+
def get_crossattention_scores(self, context_mask):
|
96 |
+
"""
|
97 |
+
Cross-attention scores are aggregated to obtain a single scalar per
|
98 |
+
passage. This scalar can be seen as a similarity score between the
|
99 |
+
question and the input passage. It is obtained by averaging the
|
100 |
+
cross-attention scores obtained on the first decoded token over heads,
|
101 |
+
layers, and tokens of the input passage.
|
102 |
+
More details in Distilling Knowledge from Reader to Retriever:
|
103 |
+
https://arxiv.org/abs/2012.04584.
|
104 |
+
"""
|
105 |
+
scores = []
|
106 |
+
n_passages = context_mask.size(1)
|
107 |
+
for mod in self.decoder.block:
|
108 |
+
scores.append(mod.layer[1].EncDecAttention.score_storage)
|
109 |
+
scores = torch.cat(scores, dim=2)
|
110 |
+
bsz, n_heads, n_layers, _ = scores.size()
|
111 |
+
# batch_size, n_head, n_layers, n_passages, text_maxlength
|
112 |
+
scores = scores.view(bsz, n_heads, n_layers, n_passages, -1)
|
113 |
+
scores = scores.masked_fill(~context_mask[:, None, None], 0.)
|
114 |
+
scores = scores.sum(dim=[1, 2, 4])
|
115 |
+
ntokens = context_mask.sum(dim=[2]) * n_layers * n_heads
|
116 |
+
scores = scores/ntokens
|
117 |
+
return scores
|
118 |
+
|
119 |
+
def overwrite_forward_crossattention(self):
|
120 |
+
"""
|
121 |
+
Replace cross-attention forward function, only used to save
|
122 |
+
cross-attention scores.
|
123 |
+
"""
|
124 |
+
for mod in self.decoder.block:
|
125 |
+
attn = mod.layer[1].EncDecAttention
|
126 |
+
attn.forward = types.MethodType(cross_attention_forward, attn)
|
127 |
+
|
128 |
+
class EncoderWrapper(torch.nn.Module):
|
129 |
+
"""
|
130 |
+
Encoder Wrapper for T5 Wrapper to obtain a Fusion-in-Decoder model.
|
131 |
+
"""
|
132 |
+
def __init__(self, encoder, use_checkpoint=False):
|
133 |
+
super().__init__()
|
134 |
+
|
135 |
+
self.encoder = encoder
|
136 |
+
apply_checkpoint_wrapper(self.encoder, use_checkpoint)
|
137 |
+
|
138 |
+
def forward(self, input_ids=None, attention_mask=None, **kwargs,):
|
139 |
+
# total_length = n_passages * passage_length
|
140 |
+
bsz, total_length = input_ids.shape
|
141 |
+
passage_length = total_length // self.n_passages
|
142 |
+
input_ids = input_ids.view(bsz*self.n_passages, passage_length)
|
143 |
+
attention_mask = attention_mask.view(bsz*self.n_passages, passage_length)
|
144 |
+
outputs = self.encoder(input_ids, attention_mask, **kwargs)
|
145 |
+
outputs = (outputs[0].view(bsz, self.n_passages*passage_length, -1), ) + outputs[1:]
|
146 |
+
return outputs
|
147 |
+
|
148 |
+
class CheckpointWrapper(torch.nn.Module):
|
149 |
+
"""
|
150 |
+
Wrapper replacing None outputs by empty tensors, which allows the use of
|
151 |
+
checkpointing.
|
152 |
+
"""
|
153 |
+
def __init__(self, module, use_checkpoint=False):
|
154 |
+
super().__init__()
|
155 |
+
self.module = module
|
156 |
+
self.use_checkpoint = use_checkpoint
|
157 |
+
|
158 |
+
def forward(self, hidden_states, attention_mask, position_bias, **kwargs):
|
159 |
+
if self.use_checkpoint and self.training:
|
160 |
+
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
161 |
+
def custom_forward(*inputs):
|
162 |
+
output = self.module(*inputs, **kwargs)
|
163 |
+
empty = torch.tensor(
|
164 |
+
[],
|
165 |
+
dtype=torch.float,
|
166 |
+
device=output[0].device,
|
167 |
+
requires_grad=True)
|
168 |
+
output = tuple(x if x is not None else empty for x in output)
|
169 |
+
return output
|
170 |
+
|
171 |
+
output = torch.utils.checkpoint.checkpoint(
|
172 |
+
custom_forward,
|
173 |
+
hidden_states,
|
174 |
+
attention_mask,
|
175 |
+
position_bias
|
176 |
+
)
|
177 |
+
output = tuple(x if x.size() != 0 else None for x in output)
|
178 |
+
else:
|
179 |
+
output = self.module(hidden_states, attention_mask, position_bias, **kwargs)
|
180 |
+
return output
|
181 |
+
|
182 |
+
def apply_checkpoint_wrapper(t5stack, use_checkpoint):
|
183 |
+
"""
|
184 |
+
Wrap each block of the encoder to enable checkpointing.
|
185 |
+
"""
|
186 |
+
block = []
|
187 |
+
for mod in t5stack.block:
|
188 |
+
wrapped_mod = CheckpointWrapper(mod, use_checkpoint)
|
189 |
+
block.append(wrapped_mod)
|
190 |
+
block = nn.ModuleList(block)
|
191 |
+
t5stack.block = block
|
192 |
+
|
193 |
+
def cross_attention_forward(
|
194 |
+
self,
|
195 |
+
input,
|
196 |
+
mask=None,
|
197 |
+
kv=None,
|
198 |
+
position_bias=None,
|
199 |
+
past_key_value_state=None,
|
200 |
+
head_mask=None,
|
201 |
+
query_length=None,
|
202 |
+
use_cache=False,
|
203 |
+
output_attentions=False,
|
204 |
+
):
|
205 |
+
"""
|
206 |
+
This only works for computing cross attention over the input
|
207 |
+
"""
|
208 |
+
assert(kv != None)
|
209 |
+
assert(head_mask == None)
|
210 |
+
assert(position_bias != None or self.has_relative_attention_bias)
|
211 |
+
|
212 |
+
bsz, qlen, dim = input.size()
|
213 |
+
n_heads, d_heads = self.n_heads, self.d_kv
|
214 |
+
klen = kv.size(1)
|
215 |
+
|
216 |
+
q = self.q(input).view(bsz, -1, n_heads, d_heads).transpose(1, 2)
|
217 |
+
if past_key_value_state == None:
|
218 |
+
k = self.k(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2)
|
219 |
+
v = self.v(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2)
|
220 |
+
else:
|
221 |
+
k, v = past_key_value_state
|
222 |
+
|
223 |
+
scores = torch.einsum("bnqd,bnkd->bnqk", q, k)
|
224 |
+
|
225 |
+
if mask is not None:
|
226 |
+
scores += mask
|
227 |
+
|
228 |
+
if position_bias is None:
|
229 |
+
position_bias = self.compute_bias(qlen, klen)
|
230 |
+
scores += position_bias
|
231 |
+
|
232 |
+
if self.score_storage is None:
|
233 |
+
self.score_storage = scores
|
234 |
+
|
235 |
+
attn = F.softmax(scores.float(), dim=-1).type_as(scores)
|
236 |
+
attn = F.dropout(attn, p=self.dropout, training=self.training)
|
237 |
+
|
238 |
+
output = torch.matmul(attn, v)
|
239 |
+
output = output.transpose(1, 2).contiguous().view(bsz, -1, self.inner_dim)
|
240 |
+
output = self.o(output)
|
241 |
+
|
242 |
+
if use_cache:
|
243 |
+
output = (output,) + ((k, v),)
|
244 |
+
else:
|
245 |
+
output = (output,) + (None,)
|
246 |
+
|
247 |
+
if output_attentions:
|
248 |
+
output = output + (attn,)
|
249 |
+
|
250 |
+
if self.has_relative_attention_bias:
|
251 |
+
output = output + (position_bias,)
|
252 |
+
|
253 |
+
return output
|
254 |
+
|
255 |
+
class RetrieverConfig(transformers.BertConfig):
|
256 |
+
|
257 |
+
def __init__(self,
|
258 |
+
indexing_dimension=768,
|
259 |
+
apply_question_mask=False,
|
260 |
+
apply_passage_mask=False,
|
261 |
+
extract_cls=False,
|
262 |
+
passage_maxlength=200,
|
263 |
+
question_maxlength=40,
|
264 |
+
projection=True,
|
265 |
+
**kwargs):
|
266 |
+
super().__init__(**kwargs)
|
267 |
+
self.indexing_dimension = indexing_dimension
|
268 |
+
self.apply_question_mask = apply_question_mask
|
269 |
+
self.apply_passage_mask = apply_passage_mask
|
270 |
+
self.extract_cls=extract_cls
|
271 |
+
self.passage_maxlength = passage_maxlength
|
272 |
+
self.question_maxlength = question_maxlength
|
273 |
+
self.projection = projection
|
274 |
+
|
275 |
+
class Retriever(transformers.PreTrainedModel):
|
276 |
+
|
277 |
+
config_class = RetrieverConfig
|
278 |
+
base_model_prefix = "retriever"
|
279 |
+
|
280 |
+
def __init__(self, config, initialize_wBERT=False):
|
281 |
+
super().__init__(config)
|
282 |
+
assert config.projection or config.indexing_dimension == 768, \
|
283 |
+
'If no projection then indexing dimension must be equal to 768'
|
284 |
+
self.config = config
|
285 |
+
if initialize_wBERT:
|
286 |
+
self.model = transformers.BertModel.from_pretrained('bert-base-uncased')
|
287 |
+
else:
|
288 |
+
self.model = transformers.BertModel(config)
|
289 |
+
if self.config.projection:
|
290 |
+
self.proj = nn.Linear(
|
291 |
+
self.model.config.hidden_size,
|
292 |
+
self.config.indexing_dimension
|
293 |
+
)
|
294 |
+
self.norm = nn.LayerNorm(self.config.indexing_dimension)
|
295 |
+
self.loss_fct = torch.nn.KLDivLoss()
|
296 |
+
|
297 |
+
def forward(self,
|
298 |
+
question_ids,
|
299 |
+
question_mask,
|
300 |
+
passage_ids,
|
301 |
+
passage_mask,
|
302 |
+
gold_score=None):
|
303 |
+
question_output = self.embed_text(
|
304 |
+
text_ids=question_ids,
|
305 |
+
text_mask=question_mask,
|
306 |
+
apply_mask=self.config.apply_question_mask,
|
307 |
+
extract_cls=self.config.extract_cls,
|
308 |
+
)
|
309 |
+
bsz, n_passages, plen = passage_ids.size()
|
310 |
+
passage_ids = passage_ids.view(bsz * n_passages, plen)
|
311 |
+
passage_mask = passage_mask.view(bsz * n_passages, plen)
|
312 |
+
passage_output = self.embed_text(
|
313 |
+
text_ids=passage_ids,
|
314 |
+
text_mask=passage_mask,
|
315 |
+
apply_mask=self.config.apply_passage_mask,
|
316 |
+
extract_cls=self.config.extract_cls,
|
317 |
+
)
|
318 |
+
|
319 |
+
score = torch.einsum(
|
320 |
+
'bd,bid->bi',
|
321 |
+
question_output,
|
322 |
+
passage_output.view(bsz, n_passages, -1)
|
323 |
+
)
|
324 |
+
score = score / np.sqrt(question_output.size(-1))
|
325 |
+
if gold_score is not None:
|
326 |
+
loss = self.kldivloss(score, gold_score)
|
327 |
+
else:
|
328 |
+
loss = None
|
329 |
+
|
330 |
+
return question_output, passage_output, score, loss
|
331 |
+
|
332 |
+
def embed_text(self, text_ids, text_mask, apply_mask=False, extract_cls=False):
|
333 |
+
text_output = self.model(
|
334 |
+
input_ids=text_ids,
|
335 |
+
attention_mask=text_mask if apply_mask else None
|
336 |
+
)
|
337 |
+
if type(text_output) is not tuple:
|
338 |
+
text_output.to_tuple()
|
339 |
+
text_output = text_output[0]
|
340 |
+
if self.config.projection:
|
341 |
+
text_output = self.proj(text_output)
|
342 |
+
text_output = self.norm(text_output)
|
343 |
+
|
344 |
+
if extract_cls:
|
345 |
+
text_output = text_output[:, 0]
|
346 |
+
else:
|
347 |
+
if apply_mask:
|
348 |
+
text_output = text_output.masked_fill(~text_mask[:, :, None], 0.)
|
349 |
+
text_output = torch.sum(text_output, dim=1) / torch.sum(text_mask, dim=1)[:, None]
|
350 |
+
else:
|
351 |
+
text_output = torch.mean(text_output, dim=1)
|
352 |
+
return text_output
|
353 |
+
|
354 |
+
def kldivloss(self, score, gold_score):
|
355 |
+
gold_score = torch.softmax(gold_score, dim=-1)
|
356 |
+
score = torch.nn.functional.log_softmax(score, dim=-1)
|
357 |
+
return self.loss_fct(score, gold_score)
|