edugp committed
Commit 98c2b8e
1 Parent(s): a77e1f7

Update downloading and training scripts

Files changed (3)
  1. prepare_wit.py +34 -16
  2. run_hybrid_clip.py +0 -1
  3. run_hybrid_clip.py +567 -0
prepare_wit.py CHANGED
@@ -3,6 +3,7 @@ import json
 import logging
 import os
 import time
+from typing import List
 import urllib.request
 import urllib.error
 
@@ -10,15 +11,38 @@ import pandas as pd
 from tqdm import tqdm
 
 
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
 logger = logging.getLogger(__name__)
 
-def prepare_wit(tsv: str, language: str, output_dir: str, seed: int, train_proportion: float, valid_proportion: float, language_col: str="language", caption_col: str="caption_reference_description", url_col: str="image_url", pause=1.0, retries: int=5):
+def split_and_save_datasets(lines: List[str], output_dir: str, train_proportion: float, valid_proportion: float):
+    total_lines = len(lines)
+    train_lines = lines[:int(total_lines * train_proportion)]
+    valid_lines = lines[int(total_lines * train_proportion):int(total_lines * (train_proportion + valid_proportion))]
+    test_lines = lines[int(total_lines * (train_proportion + valid_proportion)):]
+
+    with open(f"{output_dir}/train_dataset.json", "w") as f:
+        f.write("\n".join(train_lines))
+
+    with open(f"{output_dir}/valid_dataset.json", "w") as f:
+        f.write("\n".join(valid_lines))
+
+    with open(f"{output_dir}/test_dataset.json", "w") as f:
+        f.write("\n".join(test_lines))
+
+def prepare_wit(tsv: str, language: str, output_dir: str, seed: int, train_proportion: float, valid_proportion: float, backup_period: int, language_col: str="language", caption_col: str="caption_reference_description", url_col: str="image_url", pause=0.1, retries: int=5):
     os.makedirs(output_dir, exist_ok=True)
+    logger.info("Loading dataset")
     df = pd.read_csv(tsv, sep="\t", engine="python")
     df = df[(df["language"] == language) & (~df["caption_reference_description"].isnull())]
     # Shuffle
     df = df.sample(frac=1.0, random_state=seed)
+    logger.info("Download started")
     lines = []
+    count = 0
     try:
         with tqdm(total=len(df)) as pbar:
             for i, row in tqdm(df.iterrows()):
@@ -32,27 +56,20 @@ def prepare_wit(tsv: str, language: str, output_dir: str, seed: int, train_propo
                     # Download file
                     urllib.request.urlretrieve(url, image_path)
                     lines.append(json.dumps({"image_path": image_path, "captions": [caption]}, ensure_ascii=False))
+                    count += 1
                     break
                 except urllib.error.HTTPError as e:
-                    time.sleep(pause)
+                    # time.sleep(pause)
+                    pass
+                if count % backup_period == 0:
+                    logger.info(f"Saving dataset backup: Number of lines {len(lines)}")
+                    split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)
                 if retry == retries:
                     raise ValueError("Rate limit achieved:", e)
             pbar.update(1)
     # Save existing dataset, even upon failure
     finally:
-        total_lines = len(lines)
-        train_lines = lines[:int(total_lines * train_proportion)]
-        valid_lines = lines[int(total_lines * train_proportion):int(total_lines * (train_proportion + valid_proportion))]
-        test_lines = lines[int(total_lines * (train_proportion + valid_proportion)):]
-
-        with open(f"{output_dir}/train_dataset.json", "w") as f:
-            f.write("\n".join(train_lines))
-
-        with open(f"{output_dir}/valid_dataset.json", "w") as f:
-            f.write("\n".join(valid_lines))
-
-        with open(f"{output_dir}/test_dataset.json", "w") as f:
-            f.write("\n".join(test_lines))
+        split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description = "Download and prepare the WIT dataset")
@@ -62,7 +79,8 @@ if __name__ == "__main__":
     parser.add_argument("--random_seed", type=int, default=0)
     parser.add_argument("--train_proportion", type=float, default=0.8)
     parser.add_argument("--valid_proportion", type=float, default=0.1)
+    parser.add_argument("--backup_period", type=int, default=1000)
 
     args = parser.parse_args()
     assert args.train_proportion + args.valid_proportion < 1.0, "The sum of train_proportion and valid_proportion has to be < 1.0"
-    prepare_wit(args.tsv, args.language, args.output_dir, args.random_seed, args.train_proportion, args.valid_proportion)
+    prepare_wit(args.tsv, args.language, args.output_dir, args.random_seed, args.train_proportion, args.valid_proportion, args.backup_period)
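A minimal usage sketch of the updated download script follows (not part of the commit): the TSV path, language and output directory are hypothetical placeholders, and the new backup_period argument controls how often the partial train/valid/test splits are re-written while images are still being downloaded.

# Hypothetical usage sketch: file names and values below are placeholders, not from this commit.
from prepare_wit import prepare_wit

prepare_wit(
    tsv="wit_v1.train.all-00000-of-00010.tsv",  # placeholder WIT TSV shard
    language="es",
    output_dir="./wit_es",
    seed=0,
    train_proportion=0.8,
    valid_proportion=0.1,
    backup_period=1000,  # write a backup of the splits roughly every 1000 downloaded images
)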
run_hybrid_clip.py DELETED
@@ -1 +0,0 @@
-/home/eduardogonzalezponferrada/transformers/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py
run_hybrid_clip.py ADDED
@@ -0,0 +1,567 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2021 The HuggingFace Team All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Training a CLIP-like dual encoder model using text and vision encoders from the library.
+
+The script can be used to train CLIP-like models for languages other than English by using
+a text encoder pre-trained in the desired language. Currently this script supports the following vision
+and text models:
+Vision models: ViT (https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
+Text models: BERT, RoBERTa (https://huggingface.co/models?filter=masked-lm)
+"""
+
+import json
+import logging
+import os
+import sys
+import time
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+from torchvision.datasets import VisionDataset
+from torchvision.io import ImageReadMode, read_image
+from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
+from torchvision.transforms.functional import InterpolationMode
+from tqdm import tqdm
+
+import jax
+import jax.numpy as jnp
+import optax
+import transformers
+from flax import jax_utils
+from flax.jax_utils import unreplicate
+from flax.training import train_state
+from flax.training.common_utils import get_metrics, shard, shard_prng_key
+from modeling_hybrid_clip import FlaxHybridCLIP
+from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
+
+
+logger = logging.getLogger(__name__)
+
+# Cache the result
+has_tensorboard = is_tensorboard_available()
+if has_tensorboard:
+    try:
+        from flax.metrics.tensorboard import SummaryWriter
+    except ImportError as ie:
+        has_tensorboard = False
+        print(f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}")
+
+else:
+    print(
+        "Unable to display metrics through TensorBoard because the package is not installed: "
+        "Please run pip install tensorboard to enable."
+    )
+
+
+@dataclass
+class ModelArguments:
+    """
+    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+    """
+
+    text_model_name_or_path: str = field(
+        metadata={
+            "help": "The text model checkpoint for weights initialization."
+            "Don't set if you want to train a model from scratch."
+        },
+    )
+    vision_model_name_or_path: str = field(
+        metadata={
+            "help": "The vision model checkpoint for weights initialization."
+            "Don't set if you want to train a model from scratch."
+        },
+    )
+    from_pt: bool = field(
+        default=True,
+        metadata={"help": "whether to load the text and vision model using PyTorch checkpoints."},
+    )
+    config_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+    )
+    cache_dir: Optional[str] = field(
+        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+    )
+    use_fast_tokenizer: bool = field(
+        default=True,
+        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+    )
+    dtype: Optional[str] = field(
+        default="float32",
+        metadata={
+            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+        },
+    )
+
+
+@dataclass
+class DataTrainingArguments:
+    """
+    Arguments pertaining to what data we are going to input our model for training and eval.
+    """
+
+    data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
+    train_file: Optional[str] = field(
+        default=None, metadata={"help": "The input training data file (a jsonlines file)."}
+    )
+    validation_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
+    )
+    max_seq_length: Optional[int] = field(
+        default=72,
+        metadata={
+            "help": "The maximum total input sequence length after tokenization. Sequences longer "
+            "than this will be truncated, sequences shorter will be padded."
+        },
+    )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        },
+    )
+    max_eval_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+            "value if set."
+        },
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+    overwrite_cache: bool = field(
+        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+    )
+    preprocessing_num_workers: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of processes to use for the preprocessing."},
+    )
+
+    def __post_init__(self):
+        if self.train_file is None and self.validation_file is None:
+            raise ValueError("Need either a dataset name or a training/validation file.")
+        else:
+            if self.train_file is not None:
+                extension = self.train_file.split(".")[-1]
+                assert extension == "json", "`train_file` should be a json file."
+            if self.validation_file is not None:
+                extension = self.validation_file.split(".")[-1]
+                assert extension == "json", "`validation_file` should be a json file."
+
+
+# We use torchvision for faster image pre-processing.
+# We need to ensure faster processing speed as it can become a bottleneck on TPU
+class Transform(torch.nn.Module):
+    def __init__(self, image_size):
+        super().__init__()
+        self.transforms = torch.nn.Sequential(
+            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
+            CenterCrop(image_size),
+            ConvertImageDtype(torch.float),
+            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        with torch.no_grad():
+            x = self.transforms(x)
+        return x
+
+
+class ImageTextDataset(VisionDataset):
+    """
+    Dataset for loading image-text data for tasks like CLIP training and image captioning.
+
+    Args:
+        root: (string): The root path where the dataset is stored
+        file_path: (string): Path to the file containing the image_paths and associated captions.
+            The expected format is jsonlines where each line is a json object containing two keys.
+            `image_path`: The path to the image.
+            `captions`: An `array` of captions.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g, ``transforms.ToTensor``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        transforms (callable, optional): A function/transform that takes input sample and its target as entry
+            and returns a transformed version.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        file_path: str,
+        captions_per_image=2,
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+    ):
+        super().__init__(root, transforms, transform, target_transform)
+
+        with open(file_path, "r") as f:
+            examples = [json.loads(line) for line in f.readlines()]
+
+        self.captions = []
+        self.image_paths = []
+
+        for example in examples:
+            self.captions.extend(example["captions"][:captions_per_image])
+            self.image_paths.extend([example["image_path"]] * captions_per_image)
+
+    def _load_image(self, idx: int):
+        path = self.image_paths[idx]
+        return read_image(path, mode=ImageReadMode.RGB)
+
+    def _load_target(self, idx):
+        return self.captions[idx]
+
+    def __getitem__(self, index: int):
+        image = self._load_image(index)
+        target = self._load_target(index)
+
+        if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+        return image, target
+
+    def __len__(self) -> int:
+        return len(self.captions)
+
+
+class TrainState(train_state.TrainState):
+    dropout_rng: jnp.ndarray
+
+    def replicate(self):
+        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
+
+
+def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+    summary_writer.scalar("train_time", train_time, step)
+
+    train_metrics = get_metrics(train_metrics)
+    for key, vals in train_metrics.items():
+        tag = f"train_{key}"
+        for i, val in enumerate(vals):
+            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+    for metric_name, value in eval_metrics.items():
+        summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+
+def create_learning_rate_fn(
+    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+) -> Callable[[int], jnp.array]:
+    """Returns a linear warmup, linear_decay learning rate function."""
+    steps_per_epoch = train_ds_size // train_batch_size
+    num_train_steps = steps_per_epoch * num_train_epochs
+    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+    decay_fn = optax.linear_schedule(
+        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
+    )
+    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
+    return schedule_fn
+
+
+def main():
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    if (
+        os.path.exists(training_args.output_dir)
+        and os.listdir(training_args.output_dir)
+        and training_args.do_train
+        and not training_args.overwrite_output_dir
+    ):
+        raise ValueError(
+            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+            "Use --overwrite_output_dir to overcome."
+        )
+
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    # Setup logging, we only want one process per machine to log things on the screen.
+    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+    if jax.process_index() == 0:
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        transformers.utils.logging.set_verbosity_error()
+
+    # Set the verbosity to info of the Transformers logger (on main process only):
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    if model_args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+        )
+    elif model_args.text_model_name_or_path:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.text_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+        )
+    else:
+        raise ValueError(
+            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
+            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+        )
+
+    model = FlaxHybridCLIP.from_text_vision_pretrained(
+        model_args.text_model_name_or_path,
+        model_args.vision_model_name_or_path,
+        seed=training_args.seed,
+        dtype=getattr(jnp, model_args.dtype),
+        text_from_pt=model_args.from_pt,
+        vision_from_pt=model_args.from_pt,
+    )
+    config = model.config
+    # set seed for torch dataloaders
+    set_seed(training_args.seed)
+
+    # Initialize torchvision transforms and jit them for faster processing
+    preprocess = Transform(config.vision_config.image_size)
+    preprocess = torch.jit.script(preprocess)
+
+    # Initialize the image-text dataset
+    train_dataset = ImageTextDataset(
+        data_args.data_dir,
+        data_args.train_file,
+        captions_per_image=2,
+        transform=preprocess,
+    )
+
+    eval_dataset = ImageTextDataset(
+        data_args.data_dir,
+        data_args.validation_file,
+        captions_per_image=1,
+        transform=preprocess,
+    )
+
+    # Store some constant
+    num_epochs = int(training_args.num_train_epochs)
+    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+    steps_per_epoch = len(train_dataset) // train_batch_size
+    total_train_steps = steps_per_epoch * num_epochs
+
+    # Use collate function to tokenize the text and convert the processed images to numpy
+    def collate_fn(examples):
+        pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
+        captions = [example[1] for example in examples]
+        inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True, return_tensors="np")
+
+        batch = {
+            "pixel_values": pixel_values,
+            "input_ids": inputs["input_ids"],
+            "attention_mask": inputs["attention_mask"],
+        }
+
+        return batch
+
+    # Create data loaders
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=train_batch_size,
+        shuffle=True,
+        num_workers=data_args.preprocessing_num_workers,
+        persistent_workers=True,
+        drop_last=True,
+        collate_fn=collate_fn,
+    )
+
+    eval_loader = torch.utils.data.DataLoader(
+        eval_dataset,
+        batch_size=eval_batch_size,
+        shuffle=False,
+        num_workers=data_args.preprocessing_num_workers,
+        persistent_workers=True,
+        drop_last=True,
+        collate_fn=collate_fn,
+    )
+
+    # Enable tensorboard only on the master node
+    if has_tensorboard and jax.process_index() == 0:
+        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
+
+    # Initialize our training
+    rng = jax.random.PRNGKey(training_args.seed)
+    rng, dropout_rng = jax.random.split(rng)
+
+    # Create learning rate schedule
+    linear_decay_lr_schedule_fn = create_learning_rate_fn(
+        len(train_dataset),
+        train_batch_size,
+        training_args.num_train_epochs,
+        training_args.warmup_steps,
+        training_args.learning_rate,
+    )
+
+    # create adam optimizer
+    adamw = optax.adamw(
+        learning_rate=linear_decay_lr_schedule_fn,
+        b1=training_args.adam_beta1,
+        b2=training_args.adam_beta2,
+        eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
+    )
+
+    # Setup train state
+    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+
+    def cross_entropy(logits, axis):
+        logprobs = jax.nn.log_softmax(logits, axis=axis)
+        nll = jnp.diag(logprobs)
+        ce = -jnp.mean(nll)
+        return ce
+
+    def clip_loss(similarity):
+        loss = (cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)) / 2
+        return loss
+
+    # Define gradient update step fn
+    def train_step(state, batch):
+        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
+
+        def compute_loss(params):
+            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+            loss = clip_loss(logits)
+            return loss
+
+        grad_fn = jax.value_and_grad(compute_loss)
+        loss, grad = grad_fn(state.params)
+        grad = jax.lax.pmean(grad, "batch")
+
+        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
+
+        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+        metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+        return new_state, metrics
+
+    # Define eval fn
+    def eval_step(params, batch):
+        logits = model(**batch, params=params, train=False)[0]
+        loss = clip_loss(logits)
+
+        # summarize metrics
+        metrics = {"loss": loss}
+        metrics = jax.lax.pmean(metrics, axis_name="batch")
+        return metrics
+
+    # Create parallel version of the train and eval step
+    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+    p_eval_step = jax.pmap(eval_step, "batch")
+
+    # Replicate the train state on each device
+    state = state.replicate()
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num Epochs = {num_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
+    logger.info(f"  Total optimization steps = {total_train_steps}")
+
+    train_time = 0
+    # Create sampling rng
+    rng, input_rng = jax.random.split(rng)
+
+    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+    best_loss = np.inf
+    for epoch in epochs:
+        # ======================== Training ================================
+        train_start = time.time()
+
+        # Create sampling rng
+        rng, input_rng = jax.random.split(rng)
+        train_metrics = []
+
+        steps_per_epoch = len(train_dataset) // train_batch_size
+        train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+        # train
+        for batch in train_loader:
+            batch = shard(batch)
+            state, train_metric = p_train_step(state, batch)
+            train_metrics.append(train_metric)
+
+            train_step_progress_bar.update(1)
+
+        train_time += time.time() - train_start
+
+        train_metric = unreplicate(train_metric)
+
+        train_step_progress_bar.close()
+        epochs.write(
+            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+        )
+
+        # ======================== Evaluating ==============================
+        eval_metrics = []
+        eval_steps = len(eval_dataset) // eval_batch_size
+        eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
+        for batch in eval_loader:
+            # Model forward
+            batch = shard(batch)
+            metrics = p_eval_step(state.params, batch)
+            eval_metrics.append(metrics)
+
+            eval_step_progress_bar.update(1)
+
+        # normalize eval metrics
+        eval_metrics = get_metrics(eval_metrics)
+
+        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+
+        # Print metrics and update progress bar
+        eval_step_progress_bar.close()
+        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
+        epochs.write(desc)
+        epochs.desc = desc
+
+        # Save metrics
+        if has_tensorboard and jax.process_index() == 0:
+            cur_step = epoch * (len(train_dataset) // train_batch_size)
+            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+
+        # save checkpoint after each epoch and push checkpoint to the hub
+        if jax.process_index() == 0:
+            if eval_metrics["loss"] < best_loss:
+                logger.info(f"Saving best model with a loss = {eval_metrics['loss']}")
+                params = jax.device_get(unreplicate(state.params))
+                model.save_pretrained(
+                    training_args.output_dir,
+                    params=params,
+                    push_to_hub=training_args.push_to_hub,
+                    commit_message=f"Saving weights and logs of epoch {epoch+1}",
+                )
+                best_loss = eval_metrics["loss"]
+
+
+if __name__ == "__main__":
+    main()
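For reference, run_hybrid_clip.py reads the jsonlines files produced by prepare_wit.py through ImageTextDataset: each line is a JSON object with an image_path string and a captions array. Below is a minimal sketch of writing one such entry; the paths and caption are hypothetical placeholders, not files from this commit.

# Hypothetical sketch of the jsonlines format consumed by ImageTextDataset (placeholder paths).
import json

example = {"image_path": "./wit_es/images/00001.jpg", "captions": ["Una foto de ejemplo"]}
with open("./wit_es/train_dataset.json", "a") as f:
    f.write(json.dumps(example, ensure_ascii=False) + "\n")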