Joshua Lochner committed on
Commit
31d605f
1 Parent(s): b27b0d5

Remove unused code in training script

Files changed (1)
  1. src/train.py +18 -120
src/train.py CHANGED
@@ -54,38 +54,7 @@ class DataTrainingArguments:
         default=None,
         metadata={'help': 'The number of processes to use for the preprocessing.'},
     )
-    # https://github.com/huggingface/transformers/issues/5204
-    max_source_length: Optional[int] = field(
-        default=512,
-        metadata={
-            'help': 'The maximum total input sequence length after tokenization. Sequences longer '
-            'than this will be truncated, sequences shorter will be padded.'
-        },
-    )
-    max_target_length: Optional[int] = field(
-        default=512,
-        metadata={
-            'help': 'The maximum total sequence length for target text after tokenization. Sequences longer '
-            'than this will be truncated, sequences shorter will be padded.'
-        },
-    )
-    val_max_target_length: Optional[int] = field(
-        default=None,
-        metadata={
-            'help': 'The maximum total sequence length for validation target text after tokenization. Sequences longer '
-            'than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`.'
-            'This argument is also used to override the ``max_length`` param of ``model.generate``, which is used '
-            'during ``evaluate`` and ``predict``.'
-        },
-    )
-    pad_to_max_length: bool = field(
-        default=False,
-        metadata={
-            'help': 'Whether to pad all samples to model maximum sentence length. '
-            'If False, will pad the samples dynamically when batching to the maximum length in the batch. More '
-            'efficient on GPU but very bad for TPU.'
-        },
-    )
+
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
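The surviving fields above still follow the `dataclass` + `field(metadata={'help': ...})` pattern, which Hugging Face training scripts typically expose as command-line flags through `transformers.HfArgumentParser`. A minimal sketch of that pattern, assuming a hypothetical `SketchArguments` class that is not part of this repository:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class SketchArguments:
    # Same style as DataTrainingArguments: a default plus a --help description.
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={'help': 'Truncate the number of training examples to this value if set.'},
    )


if __name__ == '__main__':
    # e.g. `python sketch.py --max_train_samples 100`
    parser = HfArgumentParser(SketchArguments)
    (args,) = parser.parse_args_into_dataclasses()
    print(args.max_train_samples)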
@@ -104,29 +73,6 @@ class DataTrainingArguments:
             'help': 'For debugging purposes or quicker training, truncate the number of prediction examples to this value if set.'
         },
     )
-    num_beams: Optional[int] = field(
-        default=None,
-        metadata={
-            'help': 'Number of beams to use for evaluation. This argument will be passed to ``model.generate``, '
-            'which is used during ``evaluate`` and ``predict``.'
-        },
-    )
-    ignore_pad_token_for_loss: bool = field(
-        default=True,
-        metadata={
-            'help': 'Whether to ignore the tokens corresponding to padded labels in the loss computation or not.'
-        },
-    )
-    source_prefix: Optional[str] = field(
-        default=CustomTokens.EXTRACT_SEGMENTS_PREFIX.value, metadata={
-            'help': 'A prefix to add before every source text (useful for T5 models).'}
-    )
-
-    # TODO add vectorizer params
-
-    def __post_init__(self):
-        if self.val_max_target_length is None:
-            self.val_max_target_length = self.max_target_length
 
 
 @dataclass
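The removed `num_beams` field existed only to be forwarded to `model.generate` during `evaluate` and `predict`. For reference, a rough sketch of beam-search generation with a seq2seq model; the `t5-small` checkpoint and input text are illustrative, not taken from this repository:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Illustrative checkpoint; the training script loads its own model and tokenizer.
tokenizer = AutoTokenizer.from_pretrained('t5-small')
model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')

inputs = tokenizer('summarize: an example input sentence', return_tensors='pt')
# num_beams is the knob the removed dataclass field used to carry.
output_ids = model.generate(**inputs, num_beams=4, max_length=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))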
@@ -319,12 +265,6 @@ def main():
             pickle.dump(vectorizer, fp)
 
     if not training_args.skip_train_transformer:
-
-        if data_training_args.source_prefix is None and 't5-' in model_args.model_name_or_path:
-            logger.warning(
-                "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with `--source_prefix 'summarize: ' `"
-            )
-
         # Detecting last checkpoint.
         last_checkpoint = None
         if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
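The surrounding checkpoint detection follows the pattern used in the Hugging Face example scripts. A condensed sketch of that pattern, assuming `get_last_checkpoint` from `transformers.trainer_utils` and illustrative paths (the hunk does not show which helper train.py actually calls):

import os

from transformers.trainer_utils import get_last_checkpoint

# Illustrative values; in train.py these come from training_args.
output_dir = './out'
overwrite_output_dir = False
resume_from_checkpoint = None

last_checkpoint = None
if os.path.isdir(output_dir) and not overwrite_output_dir:
    # Returns the newest `checkpoint-<step>` subdirectory, or None.
    last_checkpoint = get_last_checkpoint(output_dir)
    if last_checkpoint is None and len(os.listdir(output_dir)) > 0:
        raise ValueError(
            f'Output directory ({output_dir}) already exists and is not empty. '
            'Add --overwrite_output_dir to train from scratch.')
    if last_checkpoint is not None and resume_from_checkpoint is None:
        print(f'Checkpoint detected, resuming training at {last_checkpoint}.')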
@@ -338,77 +278,38 @@ def main():
                     f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
                 )
 
-        # Load pretrained model and tokenizer
-        model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_args.model_name_or_path)
-        model.to(device())
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_args.model_name_or_path)
-
-        # Ensure model and tokenizer contain the custom tokens
-        CustomTokens.add_custom_tokens(tokenizer)
-        model.resize_token_embeddings(len(tokenizer))
-
-        if model.config.decoder_start_token_id is None:
-            raise ValueError(
-                'Make sure that `config.decoder_start_token_id` is correctly defined')
-
-        if hasattr(model.config, 'max_position_embeddings') and model.config.max_position_embeddings < data_training_args.max_source_length:
-            if model_args.resize_position_embeddings is None:
-                logger.warning(
-                    f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} to {data_training_args.max_source_length}."
-                )
-                model.resize_position_embeddings(
-                    data_training_args.max_source_length)
-
-            elif model_args.resize_position_embeddings:
-                model.resize_position_embeddings(
-                    data_training_args.max_source_length)
-
-            else:
-                raise ValueError(
-                    f'`--max_source_length` is set to {data_training_args.max_source_length}, but the model only has {model.config.max_position_embeddings}'
-                    f' position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically '
-                    "resize the model's position encodings by passing `--resize_position_embeddings`."
-                )
+        from model import get_model_tokenizer
+        model, tokenizer = get_model_tokenizer(
+            model_args.model_name_or_path, model_args.cache_dir)
+        # max_tokenizer_length = model.config.d_model
 
         # Preprocessing the datasets.
         # We need to tokenize inputs and targets.
         column_names = raw_datasets['train'].column_names
 
-        # Temporarily set max_target_length for training.
-        max_target_length = data_training_args.max_target_length
-        padding = 'max_length' if data_training_args.pad_to_max_length else False
-
-        if training_args.label_smoothing_factor > 0 and not hasattr(model, 'prepare_decoder_input_ids_from_labels'):
-            logger.warning(
-                'label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for'
-                f'`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory'
-            )
+        prefix = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value
 
-        prefix = data_training_args.source_prefix if data_training_args.source_prefix is not None else ''
+        PAD_TOKEN_REPLACE_ID = -100
 
         # https://github.com/huggingface/transformers/issues/5204
         def preprocess_function(examples):
             inputs = examples['text']
             targets = examples['extracted']
             inputs = [prefix + inp for inp in inputs]
-            model_inputs = tokenizer(
-                inputs, max_length=data_training_args.max_source_length, padding=padding, truncation=True)
+            model_inputs = tokenizer(inputs, truncation=True)
 
             # Setup the tokenizer for targets
             with tokenizer.as_target_tokenizer():
-                labels = tokenizer(
-                    targets, max_length=max_target_length, padding=padding, truncation=True)
+                labels = tokenizer(targets, truncation=True)
+
+            # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
+            # when we want to ignore padding in the loss.
 
-            # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
-            # padding in the loss.
-            if padding == 'max_length' and data_training_args.ignore_pad_token_for_loss:
-                labels['input_ids'] = [
-                    [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels['input_ids']
-                ]
-            model_inputs['labels'] = labels['input_ids']
+            model_inputs['labels'] = [
+                [(l if l != tokenizer.pad_token_id else PAD_TOKEN_REPLACE_ID)
+                 for l in label]
+                for label in labels['input_ids']
+            ]
 
             return model_inputs
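After this change, `preprocess_function` only truncates; padding is deferred to the data collator, and pad-token positions in the labels are replaced with -100 so they are ignored by the loss. A standalone sketch of the same idea, mirroring the calls in the hunk above but using an illustrative `t5-small` tokenizer and a stand-in prefix string:

from transformers import AutoTokenizer

PAD_TOKEN_REPLACE_ID = -100  # index ignored by the cross-entropy loss

tokenizer = AutoTokenizer.from_pretrained('t5-small')  # illustrative checkpoint
prefix = 'extract_segments: '  # stand-in for CustomTokens.EXTRACT_SEGMENTS_PREFIX.value


def preprocess(texts, targets):
    # Truncate only; dynamic padding happens later in the data collator.
    model_inputs = tokenizer([prefix + t for t in texts], truncation=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, truncation=True)
    # Replace any pad token ids so those positions do not contribute to the loss.
    model_inputs['labels'] = [
        [tok if tok != tokenizer.pad_token_id else PAD_TOKEN_REPLACE_ID
         for tok in label]
        for label in labels['input_ids']
    ]
    return model_inputs


batch = preprocess(['some transcript text'], ['some extracted segment'])
print(batch['input_ids'], batch['labels'])

Because -100 is also the default ignore index of PyTorch's cross-entropy loss, these positions drop out of the loss without any extra masking.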
 
@@ -434,7 +335,6 @@ def main():
         train_dataset = prepare_dataset(
             train_dataset, desc='Running tokenizer on train dataset')
 
-        max_target_length = data_training_args.val_max_target_length
         if 'validation' not in raw_datasets:
             raise ValueError('Validation dataset missing')
         eval_dataset = raw_datasets['validation']
@@ -456,12 +356,10 @@ def main():
             predict_dataset, desc='Running tokenizer on prediction dataset')
 
         # Data collator
-        label_pad_token_id = - \
-            100 if data_training_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
         data_collator = DataCollatorForSeq2Seq(
             tokenizer,
             model=model,
-            label_pad_token_id=label_pad_token_id,
+            label_pad_token_id=PAD_TOKEN_REPLACE_ID,
             pad_to_multiple_of=8 if training_args.fp16 else None,
         )
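With `label_pad_token_id=PAD_TOKEN_REPLACE_ID`, `DataCollatorForSeq2Seq` pads the variable-length label lists in each batch with -100 instead of the tokenizer's pad id, keeping padded positions out of the loss. A rough sketch of that batching step, assuming PyTorch and the same illustrative `t5-small` checkpoint:

from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          DataCollatorForSeq2Seq)

PAD_TOKEN_REPLACE_ID = -100

tokenizer = AutoTokenizer.from_pretrained('t5-small')  # illustrative checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')

data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=PAD_TOKEN_REPLACE_ID,
    pad_to_multiple_of=None,  # the script uses 8 when training with fp16
)

# Two features of different lengths: the collator pads input_ids with the
# tokenizer's pad id and pads labels with -100.
features = [
    {'input_ids': [100, 101, 102], 'labels': [200, 201]},
    {'input_ids': [100], 'labels': [200, 201, 202, 203]},
]
batch = data_collator(features)
print(batch['input_ids'])
print(batch['labels'])  # shorter label rows end in -100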
 
 