File size: 22,627 Bytes
98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 3fa7433 98c2b8e 4463ade 2daf3c7 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 3fa7433 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e 4463ade 98c2b8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 |
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Train CLIP-like dual encoder models using the text and vision encoders in the library.
The script can be used to train CLIP-like models for languages other than English by using
a text encoder pre-trained in the desired language. Currently this script supports the following vision
and text models:
Vision models: ViT (https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
Text models: BERT, RoBERTa (https://huggingface.co/models?filter=masked-lm)
"""
import json
import logging
import os
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import transformers
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, shard, shard_prng_key
from torchvision.datasets import VisionDataset
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import (CenterCrop, ConvertImageDtype, Normalize,
Resize)
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from transformers import (AutoTokenizer, HfArgumentParser, TrainingArguments,
is_tensorboard_available, set_seed)
from modeling_hybrid_clip import FlaxHybridCLIP
# Module-level logger used throughout this script.
logger = logging.getLogger(__name__)

# Probe for TensorBoard once at import time and remember the answer.
has_tensorboard = is_tensorboard_available()
if not has_tensorboard:
    print(
        "Unable to display metrics through TensorBoard because the package is not installed: "
        "Please run pip install tensorboard to enable."
    )
else:
    try:
        from flax.metrics.tensorboard import SummaryWriter
    except ImportError as import_error:
        # The flax summary writer could not be imported, so metric logging is switched off.
        has_tensorboard = False
        print(
            f"Unable to display metrics through TensorBoard because some package are not installed: {import_error}"
        )
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    `text_model_name_or_path` and `vision_model_name_or_path` have no default and must
    therefore always be supplied; every other field is optional.
    """

    text_model_name_or_path: str = field(
        metadata={
            # The trailing space keeps the two concatenated sentences apart in `--help`.
            "help": "The text model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    vision_model_name_or_path: str = field(
        metadata={
            "help": "The vision model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    from_pt: bool = field(
        default=True,
        metadata={
            "help": "whether to load the text and vision model using PyTorch checkpoints."
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from s3"
        },
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
        },
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Raises:
        ValueError: from `__post_init__`, when neither `train_file` nor `validation_file`
            is set, or when a file that is set does not carry the `json` extension.
    """

    data_dir: Optional[str] = field(
        default=None, metadata={"help": "The data directory containing input files."}
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The input training data file (a jsonlines file)."},
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
    )
    max_seq_length: Optional[int] = field(
        default=72,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    def __post_init__(self):
        """Check that at least one data file is given and that each given file is a json file."""
        if self.train_file is None and self.validation_file is None:
            raise ValueError(
                "Need either a dataset name or a training/validation file."
            )
        # Raise explicitly instead of using `assert`, which is skipped under `python -O`.
        for arg_name in ("train_file", "validation_file"):
            file_name = getattr(self, arg_name)
            if file_name is None:
                continue
            extension = file_name.split(".")[-1]
            if extension != "json":
                raise ValueError(f"`{arg_name}` should be a json file.")
# Image pre-processing is done with torchvision because it is faster.
# Speed matters here: pre-processing can become a bottleneck on TPU.
class Transform(torch.nn.Module):
    """Resize, center-crop, convert to float and normalize an image tensor."""

    def __init__(self, image_size):
        super().__init__()
        # Mean and standard deviation handed to `Normalize`.
        norm_mean = (0.48145466, 0.4578275, 0.40821073)
        norm_std = (0.26862954, 0.26130258, 0.27577711)
        pipeline = [
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ConvertImageDtype(torch.float),
            Normalize(norm_mean, norm_std),
        ]
        self.transforms = torch.nn.Sequential(*pipeline)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Run the whole pipeline with gradient tracking disabled.
        with torch.no_grad():
            x = self.transforms(x)
        return x
class ImageTextDataset(VisionDataset):
    """
    Dataset for loading image-text data for tasks like CLIP training, Image Captioning.

    Every kept caption is one sample, so an image with several kept captions appears
    several times.

    Args:
        root: (string): The root path where the dataset is stored. It is handed to
            `VisionDataset`; `image_path` values are passed to `read_image` as they are.
        file_path: (string): Path to the file containing the image_paths and associated captions.
            The expected format is jsonlines where each line is a json object containing two keys.
            `image_path`: The path to the image.
            `captions`: An `array` of captions.
            The file is read as UTF-8 and blank lines are skipped.
        captions_per_image (int, optional): Keep at most this many captions per image,
            taken from the start of the `captions` array.
        transform (callable, optional): A function/transform that takes in the image
            returned by `read_image` and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: str,
        file_path: str,
        captions_per_image=1,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ):
        super().__init__(root, transforms, transform, target_transform)

        self.captions = []
        self.image_paths = []
        # Read as UTF-8 explicitly so that non-English captions do not depend on the
        # locale's default encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                # A blank line (for example at the end of the file) is not a json object.
                if not line.strip():
                    continue
                example = json.loads(line)
                captions_subset = example["captions"][:captions_per_image]
                self.captions.extend(captions_subset)
                # Repeat the path so that captions and image paths stay index-aligned.
                self.image_paths.extend([example["image_path"]] * len(captions_subset))

    def _load_image(self, idx: int):
        """Read the image of sample `idx` in RGB mode."""
        path = self.image_paths[idx]
        return read_image(path, mode=ImageReadMode.RGB)

    def _load_target(self, idx):
        """Return the caption of sample `idx`."""
        return self.captions[idx]

    def __getitem__(self, index: int):
        """Return the `(image, caption)` pair at `index`, transformed when transforms are set."""
        image = self._load_image(index)
        target = self._load_target(index)

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        """Return the number of (image, caption) samples."""
        return len(self.captions)
class TrainState(train_state.TrainState):
    """Train state with an additional `dropout_rng` field."""

    dropout_rng: jnp.ndarray

    def replicate(self):
        """Return a replicated copy whose `dropout_rng` is produced by `shard_prng_key`."""
        replicated_state = jax_utils.replicate(self)
        sharded_rng = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=sharded_rng)
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    """Write the training time, the train metrics and the eval metrics as scalars.

    The train metrics are first passed through `get_metrics`. For each of them the last
    value is written at `step` and the earlier values at the steps directly before it.
    Training time and eval metrics are written at `step`.
    """
    summary_writer.scalar("train_time", train_time, step)

    collected = get_metrics(train_metrics)
    for name, values in collected.items():
        tag = f"train_{name}"
        first_step = step - len(values) + 1
        for offset, value in enumerate(values):
            summary_writer.scalar(tag, value, first_step + offset)

    for name, metric_value in eval_metrics.items():
        summary_writer.scalar(f"eval_{name}", metric_value, step)
def create_learning_rate_fn(
    train_ds_size: int,
    train_batch_size: int,
    num_train_epochs: int,
    num_warmup_steps: int,
    learning_rate: float,
) -> Callable[[int], jnp.array]:
    """Build a schedule that warms up linearly to `learning_rate`, then decays linearly to 0.

    The warmup takes `num_warmup_steps` steps; the decay takes the remaining steps of
    `(train_ds_size // train_batch_size) * num_train_epochs`.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    decay_steps = total_steps - num_warmup_steps

    warmup = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    decay = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=decay_steps
    )
    # Switch from the warmup schedule to the decay schedule at `num_warmup_steps`.
    return optax.join_schedules(
        schedules=[warmup, decay], boundaries=[num_warmup_steps]
    )
def main():
    """Train and evaluate a hybrid CLIP model from the command line.

    Arguments come either from a single `.json` file given as the only command-line
    argument or from regular command-line flags. The function loads the tokenizer and the
    `FlaxHybridCLIP` model, builds the image-text data loaders, and then runs one training
    pass and one evaluation pass per epoch through `jax.pmap`. On process 0, metrics are
    written to TensorBoard when it is available, and the model is saved to
    `training_args.output_dir` whenever the evaluation loss improves.

    Raises:
        ValueError: if the output directory is non-empty and may not be overwritten, if
            the train or validation file is missing, if no tokenizer can be resolved, or
            if a dataset holds fewer samples than one global batch.
    """
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments)
    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Both datasets are built unconditionally below, so both files are required.
    if data_args.train_file is None or data_args.validation_file is None:
        raise ValueError("This script needs both --train_file and --validation_file.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    # Set the verbosity to info of the Transformers logger (on main process only):
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    logger.info(f"Training/evaluation parameters {training_args}")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
        )
    elif model_args.text_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.text_model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    model = FlaxHybridCLIP.from_text_vision_pretrained(
        model_args.text_model_name_or_path,
        model_args.vision_model_name_or_path,
        seed=training_args.seed,
        dtype=getattr(jnp, model_args.dtype),
        text_from_pt=model_args.from_pt,
        vision_from_pt=model_args.from_pt,
    )
    config = model.config

    # set seed for torch dataloaders
    set_seed(training_args.seed)

    # Initialize torchvision transforms and jit them for faster processing
    preprocess = Transform(config.vision_config.image_size)
    preprocess = torch.jit.script(preprocess)

    # Initialize the image-text dataset
    train_dataset = ImageTextDataset(
        data_args.data_dir,
        data_args.train_file,
        captions_per_image=1,
        transform=preprocess,
    )
    eval_dataset = ImageTextDataset(
        data_args.data_dir,
        data_args.validation_file,
        captions_per_image=1,
        transform=preprocess,
    )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = (
        int(training_args.per_device_train_batch_size) * jax.device_count()
    )
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    eval_steps = len(eval_dataset) // eval_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Both loaders drop the last incomplete batch, so a dataset smaller than one global
    # batch would yield no batch at all and the epoch loop could not compute its metrics.
    if steps_per_epoch == 0 or eval_steps == 0:
        raise ValueError(
            "The training and evaluation sets must each contain at least one full batch "
            f"(train batch size: {train_batch_size}, eval batch size: {eval_batch_size})."
        )

    # Use collate function to tokenize the text and convert the processed images to numpy
    def collate_fn(examples):
        # Stack the images and move the channel axis last (0, 2, 3, 1) before converting to numpy.
        pixel_values = (
            torch.stack([example[0] for example in examples])
            .permute(0, 2, 3, 1)
            .numpy()
        )
        captions = [example[1] for example in examples]
        inputs = tokenizer(
            captions,
            max_length=data_args.max_seq_length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )

        batch = {
            "pixel_values": pixel_values,
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }
        return batch

    # `preprocessing_num_workers` defaults to None, which DataLoader does not accept:
    # fall back to 0 (load in the main process). Persistent workers need at least one worker.
    num_workers = data_args.preprocessing_num_workers or 0
    persistent_workers = num_workers > 0

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=persistent_workers,
        drop_last=True,
        collate_fn=collate_fn,
    )
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=persistent_workers,
        drop_last=True,
        collate_fn=collate_fn,
    )

    # Enable tensorboard only on the master node
    if has_tensorboard and jax.process_index() == 0:
        summary_writer = SummaryWriter(
            log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    # Setup train state
    state = TrainState.create(
        apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng
    )

    def cross_entropy(logits, axis):
        # The diagonal holds the log-probabilities of the matching pairs.
        logprobs = jax.nn.log_softmax(logits, axis=axis)
        nll = jnp.diag(logprobs)
        ce = -jnp.mean(nll)
        return ce

    def clip_loss(similarity):
        # Average of the cross entropies taken along both axes of the similarity matrix.
        loss = (
            cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)
        ) / 2
        return loss

    # Define gradient update step fn
    def train_step(state, batch):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            logits = state.apply_fn(
                **batch, params=params, dropout_rng=dropout_rng, train=True
            )[0]
            loss = clip_loss(logits)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average the gradients over the devices of the "batch" axis.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step),
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        logits = model(**batch, params=params, train=False)[0]
        loss = clip_loss(logits)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(
        f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f" Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    best_loss = np.inf
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        train_step_progress_bar = tqdm(
            total=steps_per_epoch, desc="Training...", position=1, leave=False
        )
        # train
        for batch in train_loader:
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            train_step_progress_bar.update(1)

        train_time += time.time() - train_start

        # Report the metrics of the last training step of the epoch.
        train_metric = unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_step_progress_bar = tqdm(
            total=eval_steps, desc="Evaluating...", position=2, leave=False
        )
        for batch in eval_loader:
            # Model forward
            batch = shard(batch)
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

            eval_step_progress_bar.update(1)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        # `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the supported spelling.
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

        # Print metrics and update progress bar
        eval_step_progress_bar.close()
        desc = (
            f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
        )
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            # `epoch` is 0-based: once it has finished, (epoch + 1) epochs of steps have run.
            cur_step = (epoch + 1) * steps_per_epoch
            write_metric(
                summary_writer, train_metrics, eval_metrics, train_time, cur_step
            )

        # save the checkpoint (and push it to the hub if requested) when the eval loss improves
        if jax.process_index() == 0:
            if eval_metrics["loss"] < best_loss:
                logger.info(f"Saving best model with a loss = {eval_metrics['loss']}")
                params = jax.device_get(unreplicate(state.params))
                model.save_pretrained(
                    training_args.output_dir,
                    params=params,
                    push_to_hub=training_args.push_to_hub,
                    commit_message=f"Saving weights and logs of epoch {epoch+1}",
                )
                best_loss = eval_metrics["loss"]
# Run the training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|