ydshieh committed
Commit 3bffad7
Parent(s): a1a9885
Use new FlaxVisionEncoderDecoderModel class
Changed files:
- run_image_caption.py +255 -101
- run_summarization_flax.py +265 -100
run_image_caption.py
CHANGED
@@ -18,11 +18,6 @@ Fine-tuning the library models for summarization.
 """
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
-import sys, os
-
-current_path = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(current_path)
-
 import logging
 import os
 import sys
|
@@ -48,20 +43,21 @@ from flax import jax_utils, traverse_util
 from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from transformers import (
     CONFIG_MAPPING,
-
     AutoConfig,
     AutoTokenizer,
     FlaxAutoModelForSeq2SeqLM,
     HfArgumentParser,
     TrainingArguments,
     is_tensorboard_available,
 )
-from transformers.file_utils import is_offline_mode
 
-from transformers import ViTFeatureExtractor, GPT2Tokenizer, GPT2Config
-from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
 
 logger = logging.getLogger(__name__)
@@ -76,10 +72,23 @@ except (LookupError, OSError):
     nltk.download("punkt", quiet=True)
 
 
-MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 
 
 @dataclass
 class ModelArguments:
     """
@@ -93,15 +102,46 @@ class ModelArguments:
             "Don't set if you want to train a model from scratch."
         },
     )
     model_type: Optional[str] = field(
         default=None,
         metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
     )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
     tokenizer_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
     )
     cache_dir: Optional[str] = field(
         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
@@ -130,19 +170,26 @@ class DataTrainingArguments:
     dataset_config_name: Optional[str] = field(
         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
     )
-    text_column: Optional[str] = field(
         default=None,
-        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
     )
-    summary_column: Optional[str] = field(
         default=None,
-        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
     )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
     )
     max_source_length: Optional[int] = field(
         default=1024,
         metadata={
@@ -191,9 +238,6 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
-    source_prefix: Optional[str] = field(
-        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
-    )
     predict_with_generate: bool = field(
         default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
     )
@@ -222,18 +266,8 @@ class DataTrainingArguments:
             self.val_max_target_length = self.max_target_length
 
 
-summarization_name_mapping = {
-    "amazon_reviews_multi": ("review_body", "review_title"),
-    "big_patent": ("description", "abstract"),
-    "cnn_dailymail": ("article", "highlights"),
-    "orange_sum": ("text", "summary"),
-    "pn_summary": ("article", "summary"),
-    "psc": ("extract_text", "summary_text"),
-    "samsum": ("dialogue", "summary"),
-    "thaisum": ("body", "summary"),
-    "xglue": ("news_body", "news_title"),
-    "xsum": ("document", "summary"),
-    "wiki_summary": ("article", "highlights"),
 }
@@ -337,6 +371,16 @@ def main():
     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")
 
     # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
@@ -347,7 +391,7 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         dataset = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, data_dir=
         )
     else:
         data_files = {}
@@ -360,38 +404,152 @@ def main():
     if data_args.test_file is not None:
         data_files["test"] = data_args.test_file
         extension = data_args.test_file.split(".")[-1]
-
-    (several removed lines are not recoverable from this render)
-    assert vit_name_path
-    assert gpt2_name_path
-    vit_gpt2_model = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
-        vit_name_path, gpt2_name_path
-    )
     else:
-    (several removed lines are not recoverable from this render)
-        )
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
@@ -405,8 +563,26 @@ def main():
         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
         return
 
-
-
     # Temporarily set max_target_length for training.
     max_target_length = data_args.max_target_length
@@ -414,29 +590,25 @@ def main():
     # In Flax, for seq2seq models we need to pass `decoder_input_ids`
     # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
     # for that dynamically import the `shift_tokens_right` function from the model file
-    model_module = __import__(
-    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
 
     # Setting padding="max_length" as we need fixed length inputs for jitted functions
     def preprocess_function(examples):
-
-
-
-        for
-            with Image.open(
                 try:
                     encoder_inputs = feature_extractor(images=image, return_tensors="np")
                 except:
                     continue
-
-
-            _captions.append(z + ' ' + tokenizer.eos_token)
-        pixel_values = np.concatenate(_pixel_values)
 
-
-
-        # Add eos_token!!
-        #targets = [x + ' ' + tokenizer.eos_token for x in targets]
 
         model_inputs = {}
         model_inputs['pixel_values'] = pixel_values
@@ -448,18 +620,13 @@ def main():
         )
 
         model_inputs["labels"] = labels["input_ids"]
-
-        #print(labels["input_ids"])
-        #print(gpt2_config.pad_token_id)
-        #rint(gpt2_config.bos_token_id)
-
         decoder_input_ids = shift_tokens_right_fn(
-            jnp.array(labels["input_ids"]),
         )
-        model_inputs["
 
         # We need decoder_attention_mask so we can ignore pad tokens from loss
-        model_inputs["
 
         return model_inputs
@@ -469,7 +636,6 @@ def main():
         train_dataset = dataset["train"]
         if data_args.max_train_samples is not None:
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
-
         train_dataset = train_dataset.map(
             preprocess_function,
             batched=True,
@@ -604,7 +770,7 @@ def main():
     )
 
     # Setup train state
-    state = TrainState.create(apply_fn=
 
     # label smoothed cross entropy
     def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
@@ -635,7 +801,7 @@ def main():
         def compute_loss(params):
             labels = batch.pop("labels")
             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-            loss = loss_fn(logits, labels, batch["
             return loss
 
         grad_fn = jax.value_and_grad(compute_loss)
@@ -653,7 +819,7 @@ def main():
     def eval_step(params, batch, label_smoothing_factor=0.0):
         labels = batch.pop("labels")
         logits = model(**batch, params=params, train=False)[0]
-        loss = loss_fn(logits, labels, batch["
 
         # summarize metrics
         metrics = {"loss": loss}
@@ -669,15 +835,7 @@ def main():
 
     def generate_step(params, batch):
         model.params = params
-        # output_ids = model.generate(batch["pixel_values"], **gen_kwargs)
-
-        #encoder_outputs = model.encode(pixel_values=batch['pixel_values'])
-        #output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], encoder_outputs=encoder_outputs, **gen_kwargs)
-
-        # encoder_outputs = model.encode(pixel_values=batch['pixel_values'], params=params, train=False)
         output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
-
-
         return output_ids.sequences
 
     # Create parallel version of the train and eval step
@@ -727,7 +885,6 @@ def main():
         with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
             fp.write(desc + '\n')
 
-
         # ======================== Evaluating ==============================
         eval_metrics = []
         eval_preds = []
@@ -768,7 +925,6 @@ def main():
         with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
             fp.write(desc + '\n')
 
-
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(train_dataset) // train_batch_size)
@@ -816,17 +972,15 @@ def main():
         logger.info(desc)
         with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
             fp.write(desc + '\n')
-
 
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-            model.save_pretrained(
-
-                push_to_hub=
-
-            )
 
 if __name__ == "__main__":
     main()
|
|
18 |
"""
|
19 |
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
20 |
|
21 |
import logging
|
22 |
import os
|
23 |
import sys
|
|
|
43 |
from flax.jax_utils import unreplicate
|
44 |
from flax.training import train_state
|
45 |
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
46 |
+
from huggingface_hub import Repository
|
47 |
from transformers import (
|
48 |
CONFIG_MAPPING,
|
49 |
+
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
50 |
AutoConfig,
|
51 |
+
AutoFeatureExtractor,
|
52 |
AutoTokenizer,
|
53 |
FlaxAutoModelForSeq2SeqLM,
|
54 |
HfArgumentParser,
|
55 |
TrainingArguments,
|
56 |
is_tensorboard_available,
|
57 |
+
FlaxAutoModelForVision2Seq,
|
58 |
)
|
59 |
+
from transformers.file_utils import get_full_repo_name, is_offline_mode
|
60 |
|
|
|
|
|
61 |
|
62 |
logger = logging.getLogger(__name__)
|
63 |
|
|
|
72 |
nltk.download("punkt", quiet=True)
|
73 |
|
74 |
|
75 |
+
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.keys())
|
76 |
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
77 |
|
78 |
|
79 |
+
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
|
80 |
+
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
|
81 |
+
"""
|
82 |
+
Shift input ids one token to the right.
|
83 |
+
"""
|
84 |
+
shifted_input_ids = np.zeros_like(input_ids)
|
85 |
+
shifted_input_ids[:, 1:] = input_ids[:, :-1]
|
86 |
+
shifted_input_ids[:, 0] = decoder_start_token_id
|
87 |
+
|
88 |
+
shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
89 |
+
return shifted_input_ids
|
90 |
+
|
91 |
+
|
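
For reference, the `shift_tokens_right` helper added above builds decoder inputs by moving every label one slot to the right and placing `decoder_start_token_id` in front; a minimal sketch with toy ids (all values made up for illustration):

import numpy as np

labels = np.array([[5, 6, 7, -100]])  # -100 marks ignored/pad positions

shifted = np.zeros_like(labels)
shifted[:, 1:] = labels[:, :-1]       # shift right by one
shifted[:, 0] = 0                     # assume decoder_start_token_id == 0
shifted = np.where(shifted == -100, 1, shifted)  # assume pad_token_id == 1

print(shifted)  # [[0 5 6 7]]
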
92 |
@dataclass
|
93 |
class ModelArguments:
|
94 |
"""
|
|
|
102 |
"Don't set if you want to train a model from scratch."
|
103 |
},
|
104 |
)
|
105 |
+
encoder_model_name_or_path: Optional[str] = field(
|
106 |
+
default=None,
|
107 |
+
metadata={
|
108 |
+
"help": "The encoder model checkpoint for weights initialization."
|
109 |
+
"Don't set if you want to train a model from scratch."
|
110 |
+
},
|
111 |
+
)
|
112 |
+
decoder_model_name_or_path: Optional[str] = field(
|
113 |
+
default=None,
|
114 |
+
metadata={
|
115 |
+
"help": "The decoder model checkpoint for weights initialization."
|
116 |
+
"Don't set if you want to train a model from scratch."
|
117 |
+
},
|
118 |
+
)
|
119 |
model_type: Optional[str] = field(
|
120 |
default=None,
|
121 |
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
122 |
)
|
123 |
+
encoder_model_type: Optional[str] = field(
|
124 |
+
default=None,
|
125 |
+
metadata={"help": "If training from scratch, pass a encoder model type from the list: " + ", ".join(MODEL_TYPES)},
|
126 |
+
)
|
127 |
+
decoder_model_type: Optional[str] = field(
|
128 |
+
default=None,
|
129 |
+
metadata={"help": "If training from scratch, pass a decoder model type from the list: " + ", ".join(MODEL_TYPES)},
|
130 |
+
)
|
131 |
config_name: Optional[str] = field(
|
132 |
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
133 |
)
|
134 |
+
encoder_config_name: Optional[str] = field(
|
135 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as encoder_model_name"}
|
136 |
+
)
|
137 |
+
decoder_config_name: Optional[str] = field(
|
138 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as decoder_model_name"}
|
139 |
+
)
|
140 |
+
feature_extractor_name: Optional[str] = field(
|
141 |
+
default=None, metadata={"help": "Pretrained feature extractor_name name or path if not the same as encoder_model_name"}
|
142 |
+
)
|
143 |
tokenizer_name: Optional[str] = field(
|
144 |
+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as decoder_model_name"}
|
145 |
)
|
146 |
cache_dir: Optional[str] = field(
|
147 |
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
|
|
170 |
dataset_config_name: Optional[str] = field(
|
171 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
172 |
)
|
173 |
+
data_dir: Optional[str] = field(
|
174 |
+
default=None, metadata={"help": "The data directory of the dataset to use (via the datasets library)."}
|
175 |
+
)
|
176 |
+
image_column: Optional[str] = field(
|
177 |
default=None,
|
178 |
+
metadata={"help": "The name of the column in the datasets containing the full image file paths (for image captioning)."},
|
179 |
)
|
180 |
+
caption_column: Optional[str] = field(
|
181 |
default=None,
|
182 |
+
metadata={"help": "The name of the column in the datasets containing the image captions (for image captioning)."},
|
183 |
)
|
184 |
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
185 |
validation_file: Optional[str] = field(
|
186 |
default=None,
|
187 |
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
188 |
)
|
189 |
+
test_file: Optional[str] = field(
|
190 |
+
default=None,
|
191 |
+
metadata={"help": "An optional input predict data file to do prediction on (a text file)."},
|
192 |
+
)
|
193 |
max_source_length: Optional[int] = field(
|
194 |
default=1024,
|
195 |
metadata={
|
|
|
238 |
default=None,
|
239 |
metadata={"help": "The number of processes to use for the preprocessing."},
|
240 |
)
|
|
|
|
|
|
|
241 |
predict_with_generate: bool = field(
|
242 |
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
243 |
)
|
|
|
266 |
self.val_max_target_length = self.max_target_length
|
267 |
|
268 |
|
269 |
+
image_captioning_name_mapping = {
|
270 |
+
"image_caption_dataset.py": ("image_file", "caption"),
|
271 |
}
|
272 |
|
273 |
|
|
|
371 |
# Set the verbosity to info of the Transformers logger (on main process only):
|
372 |
logger.info(f"Training/evaluation parameters {training_args}")
|
373 |
|
374 |
+
# Handle the repository creation
|
375 |
+
if training_args.push_to_hub:
|
376 |
+
if training_args.hub_model_id is None:
|
377 |
+
repo_name = get_full_repo_name(
|
378 |
+
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
|
379 |
+
)
|
380 |
+
else:
|
381 |
+
repo_name = training_args.hub_model_id
|
382 |
+
repo = Repository(training_args.output_dir, clone_from=repo_name)
|
383 |
+
|
384 |
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
385 |
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
386 |
# (the dataset will be downloaded automatically from the datasets Hub).
|
|
|
391 |
if data_args.dataset_name is not None:
|
392 |
# Downloading and loading a dataset from the hub.
|
393 |
dataset = load_dataset(
|
394 |
+
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, data_dir=data_args.data_dir
|
395 |
)
|
396 |
else:
|
397 |
data_files = {}
|
|
|
404 |
if data_args.test_file is not None:
|
405 |
data_files["test"] = data_args.test_file
|
406 |
extension = data_args.test_file.split(".")[-1]
|
407 |
+
# TODO: Check
|
408 |
+
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, data_dir=data_args.data_dir)
|
409 |
+
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
410 |
+
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
411 |
|
412 |
+
# Load pretrained model and tokenizer
|
413 |
+
|
414 |
+
encoder_cache_dir, decoder_cache_dir = None, None
|
415 |
+
if model_args.cache_dir:
|
416 |
+
encoder_cache_dir = os.path.join(model_args.cache_dir, "encoder")
|
417 |
+
decoder_cache_dir = os.path.join(model_args.cache_dir, "decoder")
|
418 |
+
|
419 |
+
if model_args.config_name:
|
420 |
+
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
421 |
+
elif model_args.model_name_or_path:
|
422 |
+
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
423 |
+
elif getattr(CONFIG_MAPPING[model_args.model_type], "from_encoder_decoder_configs", None):
|
424 |
+
|
425 |
+
config_class = CONFIG_MAPPING[model_args.model_type]
|
426 |
|
427 |
+
if model_args.encoder_config_name:
|
428 |
+
encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name, cache_dir=encoder_cache_dir)
|
429 |
+
elif model_args.encoder_model_name_or_path:
|
430 |
+
encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir)
|
431 |
+
else:
|
432 |
+
encoder_config = CONFIG_MAPPING[model_args.encoder_model_type]()
|
433 |
+
logger.warning("You are instantiating a new config instance from scratch for the encoder.")
|
434 |
|
435 |
+
if model_args.decoder_config_name:
|
436 |
+
decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name, cache_dir=decoder_cache_dir)
|
437 |
+
elif model_args.decoder_model_name_or_path:
|
438 |
+
decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir)
|
439 |
+
else:
|
440 |
+
decoder_config = CONFIG_MAPPING[model_args.decoder_model_type]()
|
441 |
+
logger.warning("You are instantiating a new config instance from scratch for the decoder.")
|
442 |
|
443 |
+
logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
|
444 |
+
decoder_config.is_decoder = True
|
445 |
+
decoder_config.add_cross_attention = True
|
446 |
|
447 |
+
config = config_class.from_encoder_decoder_configs(encoder_config, decoder_config)
|
448 |
else:
|
449 |
+
config = CONFIG_MAPPING[model_args.model_type]()
|
450 |
+
logger.warning("You are instantiating a new config instance from scratch.")
|
451 |
+
|
452 |
+
decoder_start_token_id = getattr(config, "decoder_start_token_id", None)
|
453 |
+
if not decoder_start_token_id and getattr(config, "decoder", None):
|
454 |
+
decoder_start_token_id = getattr(config.decoder, "decoder_start_token_id", None)
|
455 |
+
bos_token_id = getattr(config, "bos_token_id", None)
|
456 |
+
if not bos_token_id and getattr(config, "decoder", None):
|
457 |
+
bos_token_id = getattr(config.decoder, "bos_token_id", None)
|
458 |
+
eos_token_id = getattr(config, "eos_token_id", None)
|
459 |
+
if not eos_token_id and getattr(config, "decoder", None):
|
460 |
+
eos_token_id = getattr(config.decoder, "eos_token_id", None)
|
461 |
+
pad_token_id = getattr(config, "pad_token_id", None)
|
462 |
+
if not pad_token_id and getattr(config, "decoder", None):
|
463 |
+
pad_token_id = getattr(config.decoder, "pad_token_id", None)
|
464 |
+
|
465 |
+
if decoder_start_token_id is None:
|
466 |
+
decoder_start_token_id = bos_token_id
|
467 |
+
if pad_token_id is None:
|
468 |
+
pad_token_id = eos_token_id
|
469 |
+
|
470 |
+
config.decoder_start_token_id = decoder_start_token_id
|
471 |
+
config.bos_token_id = bos_token_id
|
472 |
+
config.eos_token_id = eos_token_id
|
473 |
+
config.pad_token_id = pad_token_id
|
474 |
+
|
475 |
+
if getattr(config, "decoder", None):
|
476 |
+
config.decoder.decoder_start_token_id = decoder_start_token_id
|
477 |
+
config.decoder.bos_token_id = bos_token_id
|
478 |
+
config.decoder.eos_token_id = eos_token_id
|
479 |
+
config.decoder.pad_token_id = pad_token_id
|
480 |
+
|
481 |
+
feature_extractor = None
|
482 |
+
if model_args.feature_extractor_name:
|
483 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
484 |
+
model_args.feature_extractor_name, cache_dir=model_args.cache_dir,
|
485 |
)
|
486 |
+
elif model_args.model_name_or_path:
|
487 |
+
try:
|
488 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
489 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir
|
490 |
+
)
|
491 |
+
except ValueError as e:
|
492 |
+
logger.warning(e)
|
493 |
+
if not feature_extractor:
|
494 |
+
if model_args.encoder_model_name_or_path:
|
495 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
496 |
+
model_args.encoder_model_name_or_path, cache_dir=model_args.cache_dir
|
497 |
+
)
|
498 |
+
else:
|
499 |
+
raise ValueError(
|
500 |
+
"You are instantiating a new feature extractor from scratch. This is not supported by this script."
|
501 |
+
"You can do it from another script, save it, and load it from here, using --feature_extractor_name."
|
502 |
+
)
|
503 |
|
504 |
+
tokenizer = None
|
505 |
+
if model_args.tokenizer_name:
|
506 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
507 |
+
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
508 |
+
)
|
509 |
+
elif model_args.model_name_or_path:
|
510 |
+
try:
|
511 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
512 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
513 |
+
)
|
514 |
+
except ValueError as e:
|
515 |
+
logger.warning(e)
|
516 |
+
if not tokenizer:
|
517 |
+
if model_args.decoder_model_name_or_path:
|
518 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
519 |
+
model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
520 |
+
)
|
521 |
+
else:
|
522 |
+
raise ValueError(
|
523 |
+
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
524 |
+
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
525 |
+
)
|
526 |
+
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
|
527 |
+
|
528 |
+
if model_args.model_name_or_path:
|
529 |
+
model = FlaxAutoModelForVision2Seq.from_pretrained(
|
530 |
+
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
531 |
+
)
|
532 |
+
elif model_args.encoder_model_name_or_path and model_args.decoder_model_name_or_path:
|
533 |
+
model_class = FlaxAutoModelForVision2Seq.from_config(config).__class__
|
534 |
+
model = model_class.from_encoder_decoder_pretrained(
|
535 |
+
model_args.encoder_model_name_or_path,
|
536 |
+
model_args.decoder_model_name_or_path,
|
537 |
+
encoder_config=config.encoder,
|
538 |
+
decoder_config=config.decoder,
|
539 |
+
encoder_seed=training_args.seed,
|
540 |
+
decoder_seed=training_args.seed,
|
541 |
+
encoder_dtype=getattr(jnp, model_args.dtype),
|
542 |
+
decoder_dtype=getattr(jnp, model_args.dtype),
|
543 |
+
)
|
544 |
+
# Set `encoder-decoder` (top-level) specific config
|
545 |
+
model.config.decoder_start_token_id = decoder_start_token_id
|
546 |
+
model.config.bos_token_id = bos_token_id
|
547 |
+
model.config.eos_token_id = eos_token_id
|
548 |
+
model.config.pad_token_id = pad_token_id
|
549 |
+
else:
|
550 |
+
model = FlaxAutoModelForVision2Seq.from_config(
|
551 |
+
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
552 |
+
)
|
553 |
|
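
For comparison, the `from_encoder_decoder_pretrained` branch above corresponds to roughly this manual setup; a minimal sketch, assuming the `FlaxVisionEncoderDecoderModel` class from `transformers` and the public ViT/GPT-2 checkpoints named below:

from transformers import FlaxVisionEncoderDecoderModel

# ViT encoder + GPT-2 decoder; the decoder's cross-attention weights are newly initialized
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "gpt2"
)

# GPT-2 defines no decoder_start_token_id or pad_token_id of its own,
# so fall back to bos/eos, exactly as the script does above
model.config.decoder_start_token_id = model.config.decoder.bos_token_id
model.config.pad_token_id = model.config.decoder.eos_token_id
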
554 |
# Preprocessing the datasets.
|
555 |
# We need to tokenize inputs and targets.
|
|
|
563 |
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
|
564 |
return
|
565 |
|
566 |
+
# Get the column names for input/target.
|
567 |
+
dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None)
|
568 |
+
if data_args.image_column is None:
|
569 |
+
assert dataset_columns is not None
|
570 |
+
image_column = dataset_columns[0]
|
571 |
+
else:
|
572 |
+
image_column = data_args.image_column
|
573 |
+
if image_column not in column_names:
|
574 |
+
raise ValueError(
|
575 |
+
f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
|
576 |
+
)
|
577 |
+
if data_args.caption_column is None:
|
578 |
+
assert dataset_columns is not None
|
579 |
+
caption_column = dataset_columns[1]
|
580 |
+
else:
|
581 |
+
caption_column = data_args.caption_column
|
582 |
+
if caption_column not in column_names:
|
583 |
+
raise ValueError(
|
584 |
+
f"--caption_column' value '{data_args.caption_column}' needs to be one of: {', '.join(column_names)}"
|
585 |
+
)
|
586 |
|
587 |
# Temporarily set max_target_length for training.
|
588 |
max_target_length = data_args.max_target_length
|
|
|
590 |
# In Flax, for seq2seq models we need to pass `decoder_input_ids`
|
591 |
# as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
|
592 |
# for that dynamically import the `shift_tokens_right` function from the model file
|
593 |
+
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
|
594 |
+
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right", shift_tokens_right)
|
595 |
|
596 |
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
597 |
def preprocess_function(examples):
|
598 |
+
|
599 |
+
pixel_values = []
|
600 |
+
captions = []
|
601 |
+
for image_file, caption in zip(examples[image_column], examples[caption_column]):
|
602 |
+
with Image.open(image_file) as image:
|
603 |
try:
|
604 |
encoder_inputs = feature_extractor(images=image, return_tensors="np")
|
605 |
except:
|
606 |
continue
|
607 |
+
pixel_values.append(encoder_inputs.pixel_values)
|
608 |
+
captions.append(caption + ' ' + tokenizer.eos_token)
|
|
|
|
|
609 |
|
610 |
+
pixel_values = np.concatenate(pixel_values)
|
611 |
+
targets = captions
|
|
|
|
|
612 |
|
613 |
model_inputs = {}
|
614 |
model_inputs['pixel_values'] = pixel_values
|
|
|
620 |
)
|
621 |
|
622 |
model_inputs["labels"] = labels["input_ids"]
|
|
623 |
decoder_input_ids = shift_tokens_right_fn(
|
624 |
+
jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id
|
625 |
)
|
626 |
+
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
627 |
|
628 |
# We need decoder_attention_mask so we can ignore pad tokens from loss
|
629 |
+
model_inputs["decoder_attention_mask"] = labels["attention_mask"]
|
630 |
|
631 |
return model_inputs
|
632 |
|
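
The `decoder_attention_mask` stored above is what `loss_fn` later uses to keep pad tokens out of the loss; a minimal numpy sketch of that masking (simplified from the script's label-smoothed loss, with made-up per-token values):

import numpy as np

per_token_loss = np.array([[0.7, 0.2, 0.9, 0.4]])   # last position is padding
decoder_attention_mask = np.array([[1, 1, 1, 0]])

masked = per_token_loss * decoder_attention_mask     # zero out pad positions
loss = masked.sum() / decoder_attention_mask.sum()   # average over real tokens only
print(loss)  # (0.7 + 0.2 + 0.9) / 3 = 0.6
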
|
|
636 |
train_dataset = dataset["train"]
|
637 |
if data_args.max_train_samples is not None:
|
638 |
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
|
|
639 |
train_dataset = train_dataset.map(
|
640 |
preprocess_function,
|
641 |
batched=True,
|
|
|
770 |
)
|
771 |
|
772 |
# Setup train state
|
773 |
+
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
774 |
|
775 |
# label smoothed cross entropy
|
776 |
def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
|
|
|
801 |
def compute_loss(params):
|
802 |
labels = batch.pop("labels")
|
803 |
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
804 |
+
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
805 |
return loss
|
806 |
|
807 |
grad_fn = jax.value_and_grad(compute_loss)
|
|
|
819 |
def eval_step(params, batch, label_smoothing_factor=0.0):
|
820 |
labels = batch.pop("labels")
|
821 |
logits = model(**batch, params=params, train=False)[0]
|
822 |
+
loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
|
823 |
|
824 |
# summarize metrics
|
825 |
metrics = {"loss": loss}
|
|
|
835 |
|
836 |
def generate_step(params, batch):
|
837 |
model.params = params
|
838 |
output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
|
|
|
|
|
839 |
return output_ids.sequences
|
840 |
|
841 |
# Create parallel version of the train and eval step
|
|
|
885 |
with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
|
886 |
fp.write(desc + '\n')
|
887 |
|
|
|
888 |
# ======================== Evaluating ==============================
|
889 |
eval_metrics = []
|
890 |
eval_preds = []
|
|
|
925 |
with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
|
926 |
fp.write(desc + '\n')
|
927 |
|
|
|
928 |
# Save metrics
|
929 |
if has_tensorboard and jax.process_index() == 0:
|
930 |
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
|
|
972 |
logger.info(desc)
|
973 |
with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
|
974 |
fp.write(desc + '\n')
|
|
|
975 |
|
976 |
# save checkpoint after each epoch and push checkpoint to the hub
|
977 |
if jax.process_index() == 0:
|
978 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
979 |
+
model.save_pretrained(os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'), params=params)
|
980 |
+
tokenizer.save_pretrained(training_args.output_dir)
|
981 |
+
if training_args.push_to_hub:
|
982 |
+
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
983 |
+
|
|
|
984 |
|
985 |
if __name__ == "__main__":
|
986 |
main()
|
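
Taken together, the new run_image_caption.py trains a vision-encoder-decoder captioner; a rough inference sketch under the same assumptions (ViT feature extractor, GPT-2 tokenizer; checkpoint and image paths are hypothetical):

from PIL import Image
from transformers import AutoFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel

model = FlaxVisionEncoderDecoderModel.from_pretrained("./output/ckpt_3")  # hypothetical path
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

image = Image.open("example.jpg")  # hypothetical image
pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values

# mirrors generate_step above: generate caption ids, then decode them
output_ids = model.generate(pixel_values, max_length=32).sequences
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
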
run_summarization_flax.py
CHANGED
@@ -32,6 +32,7 @@ import nltk  # Here to have a nice missing dependency error message early on
 import numpy as np
 from datasets import Dataset, load_dataset, load_metric
 from tqdm import tqdm
 
 import jax
 import jax.numpy as jnp
@@ -45,13 +46,15 @@ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from huggingface_hub import Repository
 from transformers import (
     CONFIG_MAPPING,
-
     AutoConfig,
     AutoTokenizer,
     FlaxAutoModelForSeq2SeqLM,
     HfArgumentParser,
     TrainingArguments,
     is_tensorboard_available,
 )
 from transformers.file_utils import get_full_repo_name, is_offline_mode
@@ -69,10 +72,23 @@ except (LookupError, OSError):
     nltk.download("punkt", quiet=True)
 
 
-MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 
 
 @dataclass
 class ModelArguments:
     """
@@ -86,15 +102,46 @@ class ModelArguments:
             "Don't set if you want to train a model from scratch."
         },
     )
     model_type: Optional[str] = field(
         default=None,
         metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
     )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
     tokenizer_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
     )
     cache_dir: Optional[str] = field(
         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
@@ -123,13 +170,16 @@ class DataTrainingArguments:
     dataset_config_name: Optional[str] = field(
         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
     )
-    text_column: Optional[str] = field(
         default=None,
-        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
     )
-    summary_column: Optional[str] = field(
         default=None,
-        metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
     )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
@@ -188,9 +238,6 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
-    source_prefix: Optional[str] = field(
-        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
-    )
     predict_with_generate: bool = field(
         default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
     )
@@ -219,18 +266,8 @@ class DataTrainingArguments:
             self.val_max_target_length = self.max_target_length
 
 
-summarization_name_mapping = {
-    "amazon_reviews_multi": ("review_body", "review_title"),
-    "big_patent": ("description", "abstract"),
-    "cnn_dailymail": ("article", "highlights"),
-    "orange_sum": ("text", "summary"),
-    "pn_summary": ("article", "summary"),
-    "psc": ("extract_text", "summary_text"),
-    "samsum": ("dialogue", "summary"),
-    "thaisum": ("body", "summary"),
-    "xglue": ("news_body", "news_title"),
-    "xsum": ("document", "summary"),
-    "wiki_summary": ("article", "highlights"),
 }
@@ -354,7 +391,7 @@ def main():
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
         dataset = load_dataset(
-            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
         )
     else:
         data_files = {}
@@ -367,48 +404,153 @@ def main():
     if data_args.test_file is not None:
         data_files["test"] = data_args.test_file
         extension = data_args.test_file.split(".")[-1]
-
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
     # Load pretrained model and tokenizer
 
     if model_args.config_name:
         config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
     elif model_args.model_name_or_path:
         config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
     else:
         config = CONFIG_MAPPING[model_args.model_type]()
         logger.warning("You are instantiating a new config instance from scratch.")
 
     if model_args.tokenizer_name:
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
         )
     elif model_args.model_name_or_path:
-        (removed lines not recoverable from this render)
 
     if model_args.model_name_or_path:
-        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
             model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
         )
     else:
-        model = FlaxAutoModelForSeq2SeqLM.from_config(
             config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
         )
 
-    if model.config.decoder_start_token_id is None:
-        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
-
-    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
-
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
     if training_args.do_train:
@@ -422,22 +564,24 @@ def main():
         return
 
     # Get the column names for input/target.
-    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
-    if data_args.text_column is None:
-        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
     else:
-        text_column = data_args.text_column
-        if text_column not in column_names:
             raise ValueError(
-                f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
             )
-    if data_args.summary_column is None:
-        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
     else:
-        summary_column = data_args.summary_column
-        if summary_column not in column_names:
             raise ValueError(
-                f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
             )
 
     # Temporarily set max_target_length for training.
@@ -446,17 +590,28 @@ def main():
     # In Flax, for seq2seq models we need to pass `decoder_input_ids`
     # as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
     # for that dynamically import the `shift_tokens_right` function from the model file
-    model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
-    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
 
     # Setting padding="max_length" as we need fixed length inputs for jitted functions
     def preprocess_function(examples):
-        inputs = examples[text_column]
-        targets = examples[summary_column]
-        inputs = [prefix + inp for inp in inputs]
-        model_inputs = tokenizer(
-            inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
-        )
 
         # Setup the tokenizer for targets
         with tokenizer.as_target_tokenizer():
|
|
680 |
|
681 |
def generate_step(params, batch):
|
682 |
model.params = params
|
683 |
-
output_ids = model.generate(batch[
|
684 |
return output_ids.sequences
|
685 |
|
686 |
# Create parallel version of the train and eval step
|
@@ -723,9 +878,12 @@ def main():
 
         train_metric = unreplicate(train_metric)
 
-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
-        )
 
         # ======================== Evaluating ==============================
         eval_metrics = []
@@ -763,55 +921,62 @@ def main():
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
         epochs.write(desc)
         epochs.desc = desc
 
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(train_dataset) // train_batch_size)
             write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
 
-    (the old prediction loop; removed lines not recoverable from this render)
         if data_args.predict_with_generate:
-
-            rouge_metrics = compute_metrics(pred_generations, pred_labels)
-            pred_metrics.update(rouge_metrics)
-            rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
-
-            # Print metrics
-            desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
-            logger.info(desc)
 
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-            model.save_pretrained(training_args.output_dir, params=params)
             tokenizer.save_pretrained(training_args.output_dir)
             if training_args.push_to_hub:
                 repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
|
32 |
import numpy as np
|
33 |
from datasets import Dataset, load_dataset, load_metric
|
34 |
from tqdm import tqdm
|
35 |
+
from PIL import Image
|
36 |
|
37 |
import jax
|
38 |
import jax.numpy as jnp
|
|
|
46 |
from huggingface_hub import Repository
|
47 |
from transformers import (
|
48 |
CONFIG_MAPPING,
|
49 |
+
FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
50 |
AutoConfig,
|
51 |
+
AutoFeatureExtractor,
|
52 |
AutoTokenizer,
|
53 |
FlaxAutoModelForSeq2SeqLM,
|
54 |
HfArgumentParser,
|
55 |
TrainingArguments,
|
56 |
is_tensorboard_available,
|
57 |
+
FlaxAutoModelForVision2Seq,
|
58 |
)
|
59 |
from transformers.file_utils import get_full_repo_name, is_offline_mode
|
60 |
|
|
|
72 |
nltk.download("punkt", quiet=True)
|
73 |
|
74 |
|
75 |
+
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING.keys())
|
76 |
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
77 |
|
78 |
|
79 |
+
# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
|
80 |
+
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
|
81 |
+
"""
|
82 |
+
Shift input ids one token to the right.
|
83 |
+
"""
|
84 |
+
shifted_input_ids = np.zeros_like(input_ids)
|
85 |
+
shifted_input_ids[:, 1:] = input_ids[:, :-1]
|
86 |
+
shifted_input_ids[:, 0] = decoder_start_token_id
|
87 |
+
|
88 |
+
shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
89 |
+
return shifted_input_ids
|
90 |
+
|
91 |
+
|
92 |
@dataclass
|
93 |
class ModelArguments:
|
94 |
"""
|
|
|
102 |
"Don't set if you want to train a model from scratch."
|
103 |
},
|
104 |
)
|
105 |
+
encoder_model_name_or_path: Optional[str] = field(
|
106 |
+
default=None,
|
107 |
+
metadata={
|
108 |
+
"help": "The encoder model checkpoint for weights initialization."
|
109 |
+
"Don't set if you want to train a model from scratch."
|
110 |
+
},
|
111 |
+
)
|
112 |
+
decoder_model_name_or_path: Optional[str] = field(
|
113 |
+
default=None,
|
114 |
+
metadata={
|
115 |
+
"help": "The decoder model checkpoint for weights initialization."
|
116 |
+
"Don't set if you want to train a model from scratch."
|
117 |
+
},
|
118 |
+
)
|
119 |
model_type: Optional[str] = field(
|
120 |
default=None,
|
121 |
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
122 |
)
|
123 |
+
encoder_model_type: Optional[str] = field(
|
124 |
+
default=None,
|
125 |
+
metadata={"help": "If training from scratch, pass a encoder model type from the list: " + ", ".join(MODEL_TYPES)},
|
126 |
+
)
|
127 |
+
decoder_model_type: Optional[str] = field(
|
128 |
+
default=None,
|
129 |
+
metadata={"help": "If training from scratch, pass a decoder model type from the list: " + ", ".join(MODEL_TYPES)},
|
130 |
+
)
|
131 |
config_name: Optional[str] = field(
|
132 |
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
133 |
)
|
134 |
+
encoder_config_name: Optional[str] = field(
|
135 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as encoder_model_name"}
|
136 |
+
)
|
137 |
+
decoder_config_name: Optional[str] = field(
|
138 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as decoder_model_name"}
|
139 |
+
)
|
140 |
+
feature_extractor_name: Optional[str] = field(
|
141 |
+
default=None, metadata={"help": "Pretrained feature extractor_name name or path if not the same as encoder_model_name"}
|
142 |
+
)
|
143 |
tokenizer_name: Optional[str] = field(
|
144 |
+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as decoder_model_name"}
|
145 |
)
|
146 |
cache_dir: Optional[str] = field(
|
147 |
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
|
|
170 |
dataset_config_name: Optional[str] = field(
|
171 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
172 |
)
|
173 |
+
data_dir: Optional[str] = field(
|
174 |
+
default=None, metadata={"help": "The data directory of the dataset to use (via the datasets library)."}
|
175 |
+
)
|
176 |
+
image_column: Optional[str] = field(
|
177 |
default=None,
|
178 |
+
metadata={"help": "The name of the column in the datasets containing the full image file paths (for image captioning)."},
|
179 |
)
|
180 |
+
caption_column: Optional[str] = field(
|
181 |
default=None,
|
182 |
+
metadata={"help": "The name of the column in the datasets containing the image captions (for image captioning)."},
|
183 |
)
|
184 |
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
185 |
validation_file: Optional[str] = field(
|
|
|
238 |
default=None,
|
239 |
metadata={"help": "The number of processes to use for the preprocessing."},
|
240 |
)
|
|
|
|
|
|
|
241 |
predict_with_generate: bool = field(
|
242 |
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
243 |
)
|
|
|
266 |
self.val_max_target_length = self.max_target_length
|
267 |
|
268 |
|
269 |
+
image_captioning_name_mapping = {
|
270 |
+
"image_caption_dataset.py": ("image_file", "caption"),
|
271 |
}
|
272 |
|
273 |
|
|
|
391 |
if data_args.dataset_name is not None:
|
392 |
# Downloading and loading a dataset from the hub.
|
393 |
dataset = load_dataset(
|
394 |
+
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, data_dir=data_args.data_dir
|
395 |
)
|
396 |
else:
|
397 |
data_files = {}
|
|
|
404 |
if data_args.test_file is not None:
|
405 |
data_files["test"] = data_args.test_file
|
406 |
extension = data_args.test_file.split(".")[-1]
|
407 |
+
# TODO: Check
|
408 |
+
dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, data_dir=data_args.data_dir)
|
409 |
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
410 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
411 |
|
412 |
# Load pretrained model and tokenizer
|
413 |
|
414 |
+
encoder_cache_dir, decoder_cache_dir = None, None
|
415 |
+
if model_args.cache_dir:
|
416 |
+
encoder_cache_dir = os.path.join(model_args.cache_dir, "encoder")
|
417 |
+
decoder_cache_dir = os.path.join(model_args.cache_dir, "decoder")
|
418 |
+
|
419 |
if model_args.config_name:
|
420 |
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
421 |
elif model_args.model_name_or_path:
|
422 |
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
423 |
+
elif getattr(CONFIG_MAPPING[model_args.model_type], "from_encoder_decoder_configs", None):
|
424 |
+
|
425 |
+
config_class = CONFIG_MAPPING[model_args.model_type]
|
426 |
+
|
427 |
+
if model_args.encoder_config_name:
|
428 |
+
encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name, cache_dir=encoder_cache_dir)
|
429 |
+
elif model_args.encoder_model_name_or_path:
|
430 |
+
encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path, cache_dir=encoder_cache_dir)
|
431 |
+
else:
|
432 |
+
encoder_config = CONFIG_MAPPING[model_args.encoder_model_type]()
|
433 |
+
logger.warning("You are instantiating a new config instance from scratch for the encoder.")
|
434 |
+
|
435 |
+
if model_args.decoder_config_name:
|
436 |
+
decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name, cache_dir=decoder_cache_dir)
|
437 |
+
elif model_args.decoder_model_name_or_path:
|
438 |
+
decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path, cache_dir=decoder_cache_dir)
|
439 |
+
else:
|
440 |
+
decoder_config = CONFIG_MAPPING[model_args.decoder_model_type]()
|
441 |
+
logger.warning("You are instantiating a new config instance from scratch for the decoder.")
|
442 |
+
|
443 |
+
logger.info("Setting `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
|
444 |
+
decoder_config.is_decoder = True
|
445 |
+
decoder_config.add_cross_attention = True
|
446 |
+
|
447 |
+
config = config_class.from_encoder_decoder_configs(encoder_config, decoder_config)
|
448 |
else:
|
449 |
config = CONFIG_MAPPING[model_args.model_type]()
|
450 |
logger.warning("You are instantiating a new config instance from scratch.")
|
451 |
|
452 |
+
decoder_start_token_id = getattr(config, "decoder_start_token_id", None)
|
453 |
+
if not decoder_start_token_id and getattr(config, "decoder", None):
|
454 |
+
decoder_start_token_id = getattr(config.decoder, "decoder_start_token_id", None)
|
455 |
+
bos_token_id = getattr(config, "bos_token_id", None)
|
456 |
+
if not bos_token_id and getattr(config, "decoder", None):
|
457 |
+
bos_token_id = getattr(config.decoder, "bos_token_id", None)
|
458 |
+
eos_token_id = getattr(config, "eos_token_id", None)
|
459 |
+
if not eos_token_id and getattr(config, "decoder", None):
|
460 |
+
eos_token_id = getattr(config.decoder, "eos_token_id", None)
|
461 |
+
pad_token_id = getattr(config, "pad_token_id", None)
|
462 |
+
if not pad_token_id and getattr(config, "decoder", None):
|
463 |
+
pad_token_id = getattr(config.decoder, "pad_token_id", None)
|
464 |
+
|
465 |
+
if decoder_start_token_id is None:
|
466 |
+
decoder_start_token_id = bos_token_id
|
467 |
+
if pad_token_id is None:
|
468 |
+
pad_token_id = eos_token_id
|
469 |
+
|
470 |
+
config.decoder_start_token_id = decoder_start_token_id
|
471 |
+
config.bos_token_id = bos_token_id
|
472 |
+
config.eos_token_id = eos_token_id
|
473 |
+
config.pad_token_id = pad_token_id
|
474 |
+
|
475 |
+
if getattr(config, "decoder", None):
|
476 |
+
config.decoder.decoder_start_token_id = decoder_start_token_id
|
477 |
+
config.decoder.bos_token_id = bos_token_id
|
478 |
+
config.decoder.eos_token_id = eos_token_id
|
479 |
+
config.decoder.pad_token_id = pad_token_id
|
480 |
+
|
481 |
+
feature_extractor = None
|
482 |
+
if model_args.feature_extractor_name:
|
483 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
484 |
+
model_args.feature_extractor_name, cache_dir=model_args.cache_dir,
|
485 |
+
)
|
486 |
+
elif model_args.model_name_or_path:
|
487 |
+
try:
|
488 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
489 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir
|
490 |
+
)
|
491 |
+
except ValueError as e:
|
492 |
+
logger.warning(e)
|
493 |
+
if not feature_extractor:
|
494 |
+
if model_args.encoder_model_name_or_path:
|
495 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
496 |
+
model_args.encoder_model_name_or_path, cache_dir=model_args.cache_dir
|
497 |
+
)
|
498 |
+
else:
|
499 |
+
raise ValueError(
|
500 |
+
"You are instantiating a new feature extractor from scratch. This is not supported by this script."
|
501 |
+
"You can do it from another script, save it, and load it from here, using --feature_extractor_name."
|
502 |
+
)
|
503 |
+
|
504 |
+
tokenizer = None
|
505 |
if model_args.tokenizer_name:
|
506 |
tokenizer = AutoTokenizer.from_pretrained(
|
507 |
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
508 |
)
|
509 |
elif model_args.model_name_or_path:
|
510 |
+
try:
|
511 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
512 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
513 |
+
)
|
514 |
+
except ValueError as e:
|
515 |
+
logger.warning(e)
|
516 |
+
if not tokenizer:
|
517 |
+
if model_args.decoder_model_name_or_path:
|
518 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
519 |
+
model_args.decoder_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
520 |
+
)
|
521 |
+
else:
|
522 |
+
raise ValueError(
|
523 |
+
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
524 |
+
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
525 |
+
)
|
526 |
+
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
|
527 |
|
528 |
if model_args.model_name_or_path:
|
529 |
+
model = FlaxAutoModelForVision2Seq.from_pretrained(
|
530 |
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
531 |
)
|
532 |
+
elif model_args.encoder_model_name_or_path and model_args.decoder_model_name_or_path:
|
533 |
+
model_class = FlaxAutoModelForVision2Seq.from_config(config).__class__
|
534 |
+
model = model_class.from_encoder_decoder_pretrained(
|
535 |
+
model_args.encoder_model_name_or_path,
|
536 |
+
model_args.decoder_model_name_or_path,
|
537 |
+
encoder_config=config.encoder,
|
538 |
+
decoder_config=config.decoder,
|
539 |
+
encoder_seed=training_args.seed,
|
540 |
+
decoder_seed=training_args.seed,
|
541 |
+
encoder_dtype=getattr(jnp, model_args.dtype),
|
542 |
+
decoder_dtype=getattr(jnp, model_args.dtype),
|
543 |
+
)
|
544 |
+
# Set `encoder-decoder` (top-level) specific config
|
545 |
+
model.config.decoder_start_token_id = decoder_start_token_id
|
546 |
+
model.config.bos_token_id = bos_token_id
|
547 |
+
model.config.eos_token_id = eos_token_id
|
548 |
+
model.config.pad_token_id = pad_token_id
|
549 |
else:
|
550 |
+
model = FlaxAutoModelForVision2Seq.from_config(
|
551 |
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
552 |
)
|
553 |
|
554 |
# Preprocessing the datasets.
|
555 |
# We need to tokenize inputs and targets.
|
556 |
if training_args.do_train:
|
|
|
564 |
return
|
565 |
|
566 |
# Get the column names for input/target.
|
567 |
+
dataset_columns = image_captioning_name_mapping.get(data_args.dataset_name, None)
|
568 |
+
if data_args.image_column is None:
|
569 |
+
assert dataset_columns is not None
|
570 |
+
image_column = dataset_columns[0]
|
571 |
else:
|
572 |
+
image_column = data_args.image_column
|
573 |
+
if image_column not in column_names:
|
574 |
raise ValueError(
|
575 |
+
f"--image_column' value '{data_args.image_column}' needs to be one of: {', '.join(column_names)}"
|
576 |
)
|
577 |
+
if data_args.caption_column is None:
|
578 |
+
assert dataset_columns is not None
|
579 |
+
caption_column = dataset_columns[1]
|
580 |
else:
|
581 |
+
caption_column = data_args.caption_column
|
582 |
+
if caption_column not in column_names:
|
583 |
raise ValueError(
|
584 |
+
f"--caption_column' value '{data_args.caption_column}' needs to be one of: {', '.join(column_names)}"
|
585 |
)
|
586 |
|
587 |
# Temporarily set max_target_length for training.
|
|
|
590 |
# In Flax, for seq2seq models we need to pass `decoder_input_ids`
|
591 |
# as the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here
|
592 |
# for that dynamically import the `shift_tokens_right` function from the model file
|
593 |
+
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
|
594 |
+
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right", shift_tokens_right)
|
595 |
|
596 |
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
597 |
def preprocess_function(examples):
|
598 |
+
|
599 |
+
pixel_values = []
|
600 |
+
captions = []
|
601 |
+
for image_file, caption in zip(examples[image_column], examples[caption_column]):
|
602 |
+
with Image.open(image_file) as image:
|
603 |
+
try:
|
604 |
+
encoder_inputs = feature_extractor(images=image, return_tensors="np")
|
605 |
+
except:
|
606 |
+
continue
|
607 |
+
pixel_values.append(encoder_inputs.pixel_values)
|
608 |
+
captions.append(caption + ' ' + tokenizer.eos_token)
|
609 |
+
|
610 |
+
pixel_values = np.concatenate(pixel_values)
|
611 |
+
targets = captions
|
612 |
+
|
613 |
+
model_inputs = {}
|
614 |
+
model_inputs['pixel_values'] = pixel_values
|
615 |
|
616 |
# Setup the tokenizer for targets
|
617 |
with tokenizer.as_target_tokenizer():
|
|
|
835 |
|
836 |
def generate_step(params, batch):
|
837 |
model.params = params
|
838 |
+
output_ids = model.generate(batch['pixel_values'], **gen_kwargs)
|
839 |
return output_ids.sequences
|
840 |
|
841 |
# Create parallel version of the train and eval step
|
|
|
878 |
|
879 |
train_metric = unreplicate(train_metric)
|
880 |
|
881 |
+
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
882 |
+
epochs.write(desc)
|
883 |
+
epochs.desc = desc
|
884 |
+
logger.info(desc)
|
885 |
+
with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
|
886 |
+
fp.write(desc + '\n')
|
887 |
|
888 |
# ======================== Evaluating ==============================
|
889 |
eval_metrics = []
|
|
|
921 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
|
922 |
epochs.write(desc)
|
923 |
epochs.desc = desc
|
924 |
+
logger.info(desc)
|
925 |
+
with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
|
926 |
+
fp.write(desc + '\n')
|
927 |
|
928 |
# Save metrics
|
929 |
if has_tensorboard and jax.process_index() == 0:
|
930 |
cur_step = epoch * (len(train_dataset) // train_batch_size)
|
931 |
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
932 |
|
933 |
+
# ======================== Prediction loop ==============================
|
934 |
+
if training_args.do_predict:
|
935 |
+
logger.info("*** Predict ***")
|
936 |
+
|
937 |
+
pred_metrics = []
|
938 |
+
pred_generations = []
|
939 |
+
pred_labels = []
|
940 |
+
|
941 |
+
pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
|
942 |
+
pred_steps = len(predict_dataset) // eval_batch_size
|
943 |
+
for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
|
944 |
+
# Model forward
|
945 |
+
batch = next(pred_loader)
|
946 |
+
labels = batch["labels"]
|
947 |
+
|
948 |
+
metrics = p_eval_step(state.params, batch)
|
949 |
+
pred_metrics.append(metrics)
|
950 |
+
|
951 |
+
# generation
|
952 |
+
if data_args.predict_with_generate:
|
953 |
+
generated_ids = p_generate_step(state.params, batch)
|
954 |
+
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
955 |
+
pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
|
956 |
+
|
957 |
+
# normalize prediction metrics
|
958 |
+
pred_metrics = get_metrics(pred_metrics)
|
959 |
+
pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
|
960 |
+
|
961 |
+
# compute ROUGE metrics
|
962 |
+
rouge_desc = ""
|
963 |
if data_args.predict_with_generate:
|
964 |
+
rouge_metrics = compute_metrics(pred_generations, pred_labels)
|
965 |
+
pred_metrics.update(rouge_metrics)
|
966 |
+
rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])
|
967 |
+
|
968 |
+
# Print metrics
|
969 |
+
desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc})"
|
970 |
+
epochs.write(desc)
|
971 |
+
epochs.desc = desc
|
972 |
+
logger.info(desc)
|
973 |
+
with open(os.path.join(training_args.output_dir, f'report.txt'), 'a', encoding='UTF-8') as fp:
|
974 |
+
fp.write(desc + '\n')
|
975 |
|
976 |
# save checkpoint after each epoch and push checkpoint to the hub
|
977 |
if jax.process_index() == 0:
|
978 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
979 |
+
model.save_pretrained(os.path.join(training_args.output_dir, f'ckpt_{epoch+1}'), params=params)
|
980 |
tokenizer.save_pretrained(training_args.output_dir)
|
981 |
if training_args.push_to_hub:
|
982 |
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
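
Both scripts now write a per-epoch checkpoint under output_dir/ckpt_{epoch+1} (the tokenizer is saved once at the top level) and optionally push via `Repository`; reloading one might look like this (a sketch; the directory names are hypothetical):

import os
from transformers import AutoTokenizer, FlaxAutoModelForVision2Seq

output_dir = "./output"  # hypothetical --output_dir
model = FlaxAutoModelForVision2Seq.from_pretrained(os.path.join(output_dir, "ckpt_3"))
tokenizer = AutoTokenizer.from_pretrained(output_dir)
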