Feature Extraction
Transformers
Safetensors
vision-encoder-decoder
custom_code
anicolson committed
Commit c613151
1 Parent(s): 9365e5f

Upload model

Files changed (2)
  1. config.json +3 -0
  2. modelling_cxrrg.py +10 -14
config.json CHANGED
@@ -74,6 +74,9 @@
     2
   ],
   "sep_token_id": null,
+  "separator_token_ids": [
+    3
+  ],
   "suppress_tokens": null,
   "task_specific_params": null,
   "temperature": 1.0,
modelling_cxrrg.py CHANGED
@@ -200,7 +200,6 @@ class CXRRGModel(VisionEncoderDecoderModel):
     def prepare_inputs_for_generation(
         self,
         input_ids,
-        special_token_ids,
         past_key_values=None,
         use_cache=None,
         encoder_outputs=None,
@@ -226,7 +225,7 @@ class CXRRGModel(VisionEncoderDecoderModel):
             # `inputs_embeds` are only to be used in the 1st generation step:
             inputs_embeds = torch.cat([encoder_outputs[0], self.decoder.get_input_embeddings()(input_ids)], dim=1)
 
-            decoder_token_type_ids = self.token_ids_to_token_type_ids(input_ids, special_token_ids)
+            decoder_token_type_ids = self.token_ids_to_token_type_ids(input_ids)
             decoder_token_type_ids = torch.cat(
                 [
                     torch.full(
@@ -255,7 +254,7 @@ class CXRRGModel(VisionEncoderDecoderModel):
             decoder_position_ids.masked_fill_(report_attention_mask == 0, 1)
 
             # Always place token_ids_to_token_type_ids_past before input_ids = input_ids[:, remove_prefix_length:]:
-            decoder_token_type_ids = self.token_ids_to_token_type_ids_past(input_ids, special_token_ids)
+            decoder_token_type_ids = self.token_ids_to_token_type_ids_past(input_ids)
             decoder_position_ids = decoder_position_ids[:, -1:]
 
             past_length = past_key_values[0][0].shape[2]
@@ -282,13 +281,12 @@ class CXRRGModel(VisionEncoderDecoderModel):
         )
         return input_dict
 
-    def token_ids_to_token_type_ids(self, token_ids, special_token_ids):
+    def token_ids_to_token_type_ids(self, token_ids):
         """
         Extract token type identifiers from the token identifiers.
 
         Argument/s:
             token_ids - token identifiers.
-            special_token_ids - special token identifiers that indicate the separation between sections.
             token_type_id_section - token type identifier for each section.
 
         Returns:
@@ -298,7 +296,7 @@ class CXRRGModel(VisionEncoderDecoderModel):
         mbatch_size, seq_len = token_ids.shape
         token_type_ids = torch.full_like(token_ids, self.config.section_ids[0], dtype=torch.long, device=token_ids.device)
 
-        for i, j in enumerate(special_token_ids):
+        for i, j in enumerate(self.config.decoder.separator_token_ids):
             # Find first occurrence of special tokens that indicate the boundary between sections:
             cols = (token_ids == j).int().argmax(dim=1)
             rows = torch.arange(mbatch_size, device=token_ids.device)
@@ -323,14 +321,13 @@ class CXRRGModel(VisionEncoderDecoderModel):
 
         return token_type_ids
 
-    def token_ids_to_token_type_ids_past(self, token_ids, special_token_ids):
+    def token_ids_to_token_type_ids_past(self, token_ids):
         """
         Extract token type identifiers from the token identifiers if past != None. Make sure to input all the
         token_ids (e.g., do not input input_ids = input_ids[:, remove_prefix_length:] from prepare_inputs_for_generation).
 
         Argument/s:
             token_ids - token identifiers.
-            special_token_ids - special token identifiers that indicate the separation between sections.
 
         Returns:
             token_type_ids - token type identifiers.
@@ -341,7 +338,7 @@ class CXRRGModel(VisionEncoderDecoderModel):
         # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer.create_token_type_ids_from_sequences.example
         token_ids = token_ids[:, :-1]
 
-        for i, j in enumerate(special_token_ids):
+        for i, j in enumerate(self.config.decoder.separator_token_ids):
 
             # Find first occurrence of special token, which indicates the boundary between sections:
             exists = torch.any(token_ids == j, dim=1, keepdim=True)
@@ -445,13 +442,12 @@ class CXRRGModel(VisionEncoderDecoderModel):
 
         return batch_dict
 
-    def split_and_decode_sections(self, token_ids, special_token_ids, tokenizer: PreTrainedTokenizerFast):
+    def split_and_decode_sections(self, token_ids, tokenizer: PreTrainedTokenizerFast):
         """
         Split the token identifiers into sections, then convert the token identifiers into strings.
 
         Argument/s:
             token_ids - token identifiers.
-            special_token_ids - special token identifiers that indicate the end of each section.
             tokenizer - Hugging Face tokenizer.
 
         Returns:
@@ -460,14 +456,14 @@ class CXRRGModel(VisionEncoderDecoderModel):
 
         _, seq_len = token_ids.shape
 
-        # The number of sections is the same as the number of special_token_ids:
-        num_sections = len(special_token_ids)
+        # The number of sections is the same as the number of separator_token_ids:
+        num_sections = len(self.config.decoder.separator_token_ids)
 
         sections = {k: [] for k in range(num_sections)}
 
         for i in token_ids:
             prev_col = 0
-            for j, k in enumerate(special_token_ids):
+            for j, k in enumerate(self.config.decoder.separator_token_ids):
 
                 # The maximum sequence length was exceeded, thus no more tokens:
                 if prev_col >= seq_len:
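
With the separator ids stored on the config, the public methods take one fewer argument. A minimal usage sketch of the updated split_and_decode_sections signature; the repository id, the loading classes, and the way output_ids are produced are assumptions for illustration only:

# Usage sketch only; repository id and dummy output_ids are assumptions.
import torch
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("user/cxrrg-model", trust_remote_code=True)      # hypothetical repo id
tokenizer = AutoTokenizer.from_pretrained("user/cxrrg-model", trust_remote_code=True)

# output_ids would normally come from model.generate(...); a dummy tensor stands in here:
output_ids = torch.tensor([[1, 5, 7, 3, 9, 2]])

# The separator ids (e.g. [3] from config.json) are now read from
# model.config.decoder.separator_token_ids inside the method, so they are
# no longer passed as an argument:
sections = model.split_and_decode_sections(output_ids, tokenizer)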