Upload model
Browse files
- config.json +3 -0
- modelling_cxrrg.py +10 -14
config.json
CHANGED
@@ -74,6 +74,9 @@
|
|
74 |
2
|
75 |
],
|
76 |
"sep_token_id": null,
|
|
|
|
|
|
|
77 |
"suppress_tokens": null,
|
78 |
"task_specific_params": null,
|
79 |
"temperature": 1.0,
|
|
|
74 |
2
|
75 |
],
|
76 |
"sep_token_id": null,
|
77 |
+
"separator_token_ids": [
|
78 |
+
3
|
79 |
+
],
|
80 |
"suppress_tokens": null,
|
81 |
"task_specific_params": null,
|
82 |
"temperature": 1.0,
|
modelling_cxrrg.py
CHANGED
@@ -200,7 +200,6 @@ class CXRRGModel(VisionEncoderDecoderModel):
|
|
200 |
def prepare_inputs_for_generation(
|
201 |
self,
|
202 |
input_ids,
|
203 |
-
special_token_ids,
|
204 |
past_key_values=None,
|
205 |
use_cache=None,
|
206 |
encoder_outputs=None,
|
@@ -226,7 +225,7 @@ class CXRRGModel(VisionEncoderDecoderModel):
|
|
226 |
# `inputs_embeds` are only to be used in the 1st generation step:
|
227 |
inputs_embeds = torch.cat([encoder_outputs[0], self.decoder.get_input_embeddings()(input_ids)], dim=1)
|
228 |
|
229 |
-
decoder_token_type_ids = self.token_ids_to_token_type_ids(input_ids
|
230 |
decoder_token_type_ids = torch.cat(
|
231 |
[
|
232 |
torch.full(
|
@@ -255,7 +254,7 @@ class CXRRGModel(VisionEncoderDecoderModel):
|
|
255 |
decoder_position_ids.masked_fill_(report_attention_mask == 0, 1)
|
256 |
|
257 |
# Always place token_ids_to_token_type_ids_past before input_ids = input_ids[:, remove_prefix_length:]:
|
258 |
-
decoder_token_type_ids = self.token_ids_to_token_type_ids_past(input_ids
|
259 |
decoder_position_ids = decoder_position_ids[:, -1:]
|
260 |
|
261 |
past_length = past_key_values[0][0].shape[2]
|
@@ -282,13 +281,12 @@ class CXRRGModel(VisionEncoderDecoderModel):
|
|
282 |
)
|
283 |
return input_dict
|
284 |
|
285 |
-
def token_ids_to_token_type_ids(self, token_ids
|
286 |
"""
|
287 |
Extract token type identifiers from the token identifiers.
|
288 |
|
289 |
Argument/s:
|
290 |
token_ids - token identifiers.
|
291 |
-
special_token_ids - special token identifiers that indicate the separation between sections.
|
292 |
token_type_id_section - token type identifier for each section.
|
293 |
|
294 |
Returns:
|
@@ -298,7 +296,7 @@ class CXRRGModel(VisionEncoderDecoderModel):
|
|
298 |
mbatch_size, seq_len = token_ids.shape
|
299 |
token_type_ids = torch.full_like(token_ids, self.config.section_ids[0], dtype=torch.long, device=token_ids.device)
|
300 |
|
301 |
-
for i, j in enumerate(
|
302 |
# Find first occurrence of special tokens that indicate the boundary between sections:
|
303 |
cols = (token_ids == j).int().argmax(dim=1)
|
304 |
rows = torch.arange(mbatch_size, device=token_ids.device)
|
@@ -323,14 +321,13 @@ class CXRRGModel(VisionEncoderDecoderModel):
|
|
323 |
|
324 |
return token_type_ids
|
325 |
|
326 |
-
def token_ids_to_token_type_ids_past(self, token_ids
|
327 |
"""
|
328 |
Extract token type identifiers from the token identifiers if past != None. Make sure to input all the
|
329 |
token_ids (e.g., do not input input_ids = input_ids[:, remove_prefix_length:] from prepare_inputs_for_generation).
|
330 |
|
331 |
Argument/s:
|
332 |
token_ids - token identifiers.
|
333 |
-
special_token_ids - special token identifiers that indicate the separation between sections.
|
334 |
|
335 |
Returns:
|
336 |
token_type_ids - token type identifiers.
|
@@ -341,7 +338,7 @@ class CXRRGModel(VisionEncoderDecoderModel):
|
|
341 |
# https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
|
342 |
token_ids = token_ids[:, :-1]
|
343 |
|
344 |
-
for i, j in enumerate(
|
345 |
|
346 |
# Find first occurrence of special token, which indicates the boundary between sections:
|
347 |
exists = torch.any(token_ids == j, dim=1, keepdim=True)
|
@@ -445,13 +442,12 @@ class CXRRGModel(VisionEncoderDecoderModel):
|
|
445 |
|
446 |
return batch_dict
|
447 |
|
448 |
-
def split_and_decode_sections(self, token_ids,
|
449 |
"""
|
450 |
Split the token identifiers into sections, then convert the token identifiers into strings.
|
451 |
|
452 |
Argument/s:
|
453 |
token_ids - token identifiers.
|
454 |
-
special_token_ids - special token identifiers that indicate the end of each section.
|
455 |
tokenizer - Hugging Face tokenizer.
|
456 |
|
457 |
Returns:
|
@@ -460,14 +456,14 @@ class CXRRGModel(VisionEncoderDecoderModel):
|
|
460 |
|
461 |
_, seq_len = token_ids.shape
|
462 |
|
463 |
-
# The number of sections is the same as the number of special_token_ids:
|
464 |
-
num_sections = len(
|
465 |
|
466 |
sections = {k: [] for k in range(num_sections)}
|
467 |
|
468 |
for i in token_ids:
|
469 |
prev_col = 0
|
470 |
-
for j, k in enumerate(
|
471 |
|
472 |
# The maximum sequence length was exceeded, thus no more tokens:
|
473 |
if prev_col >= seq_len:
|
|
|
200 |
def prepare_inputs_for_generation(
|
201 |
self,
|
202 |
input_ids,
|
|
|
203 |
past_key_values=None,
|
204 |
use_cache=None,
|
205 |
encoder_outputs=None,
|
|
|
225 |
# `inputs_embeds` are only to be used in the 1st generation step:
|
226 |
inputs_embeds = torch.cat([encoder_outputs[0], self.decoder.get_input_embeddings()(input_ids)], dim=1)
|
227 |
|
228 |
+
decoder_token_type_ids = self.token_ids_to_token_type_ids(input_ids)
|
229 |
decoder_token_type_ids = torch.cat(
|
230 |
[
|
231 |
torch.full(
|
|
|
254 |
decoder_position_ids.masked_fill_(report_attention_mask == 0, 1)
|
255 |
|
256 |
# Always place token_ids_to_token_type_ids_past before input_ids = input_ids[:, remove_prefix_length:]:
|
257 |
+
decoder_token_type_ids = self.token_ids_to_token_type_ids_past(input_ids)
|
258 |
decoder_position_ids = decoder_position_ids[:, -1:]
|
259 |
|
260 |
past_length = past_key_values[0][0].shape[2]
|
|
|
281 |
)
|
282 |
return input_dict
|
283 |
|
284 |
+
def token_ids_to_token_type_ids(self, token_ids):
|
285 |
"""
|
286 |
Extract token type identifiers from the token identifiers.
|
287 |
|
288 |
Argument/s:
|
289 |
token_ids - token identifiers.
|
|
|
290 |
token_type_id_section - token type identifier for each section.
|
291 |
|
292 |
Returns:
|
|
|
296 |
mbatch_size, seq_len = token_ids.shape
|
297 |
token_type_ids = torch.full_like(token_ids, self.config.section_ids[0], dtype=torch.long, device=token_ids.device)
|
298 |
|
299 |
+
for i, j in enumerate(self.config.decoder.separator_token_ids):
|
300 |
# Find first occurrence of special tokens that indicate the boundary between sections:
|
301 |
cols = (token_ids == j).int().argmax(dim=1)
|
302 |
rows = torch.arange(mbatch_size, device=token_ids.device)
|
|
|
321 |
|
322 |
return token_type_ids
|
323 |
|
324 |
+
def token_ids_to_token_type_ids_past(self, token_ids):
|
325 |
"""
|
326 |
Extract token type identifiers from the token identifiers if past != None. Make sure to input all the
|
327 |
token_ids (e.g., do not input input_ids = input_ids[:, remove_prefix_length:] from prepare_inputs_for_generation).
|
328 |
|
329 |
Argument/s:
|
330 |
token_ids - token identifiers.
|
|
|
331 |
|
332 |
Returns:
|
333 |
token_type_ids - token type identifiers.
|
|
|
338 |
# https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
|
339 |
token_ids = token_ids[:, :-1]
|
340 |
|
341 |
+
for i, j in enumerate(self.config.decoder.separator_token_ids):
|
342 |
|
343 |
# Find first occurrence of special token, which indicates the boundary between sections:
|
344 |
exists = torch.any(token_ids == j, dim=1, keepdim=True)
|
|
|
442 |
|
443 |
return batch_dict
|
444 |
|
445 |
+
def split_and_decode_sections(self, token_ids, tokenizer: PreTrainedTokenizerFast):
|
446 |
"""
|
447 |
Split the token identifiers into sections, then convert the token identifiers into strings.
|
448 |
|
449 |
Argument/s:
|
450 |
token_ids - token identifiers.
|
|
|
451 |
tokenizer - Hugging Face tokenizer.
|
452 |
|
453 |
Returns:
|
|
|
456 |
|
457 |
_, seq_len = token_ids.shape
|
458 |
|
459 |
+
# The number of sections is the same as the number of separator_token_ids:
|
460 |
+
num_sections = len(self.config.decoder.separator_token_ids)
|
461 |
|
462 |
sections = {k: [] for k in range(num_sections)}
|
463 |
|
464 |
for i in token_ids:
|
465 |
prev_col = 0
|
466 |
+
for j, k in enumerate(self.config.decoder.separator_token_ids):
|
467 |
|
468 |
# The maximum sequence length was exceeded, thus no more tokens:
|
469 |
if prev_col >= seq_len:
|