#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
import sys

import datasets
import numpy as np
import torch
import transformers
from aac_metrics import evaluate
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoTokenizer,
    BartConfig,
    get_inverse_sqrt_schedule,
    get_scheduler,
)

from data.collator import DataCollatorForEnClapBart
from data.preprocess import Preprocessor
from modeling.enclap_bart import EnClapBartForConditionalGeneration

logger = get_logger(__name__)
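
# Captioning metrics computed with aac-metrics during validation
# ("spider" is SPIDEr, the mean of SPICE and CIDEr-D).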
metric_list = ["meteor", "spider"]
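

# Usage (a sketch; the config filename is illustrative, not fixed by this repo):
#   python train.py path/to/config.yaml
# The OmegaConf YAML must define every field read from `args` in main()
# (output_dir, train_file, validation_file, tokenizer_name, learning_rate, ...).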
def main():
    # Load Configuration
    cfg_path = sys.argv[1]
    args = OmegaConf.load(cfg_path)

    # Initialize Logging
    accelerator_log_kwargs = {}
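    # find_unused_parameters=True lets DDP tolerate parameters that receive no
    # gradient on some steps (presumably because the auxiliary encodec loss is
    # not always active).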
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    if args.with_tracking:
        accelerator_log_kwargs["log_with"] = args.report_to
        accelerator_log_kwargs["project_dir"] = args.output_dir

    # Initialize Accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        split_batches=args.split_batches,
        kwargs_handlers=[ddp_kwargs],
        **accelerator_log_kwargs,
    )

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
            with open(os.path.join(args.output_dir, "args.yaml"), "w") as f:
                OmegaConf.save(args, f)
    accelerator.wait_for_everyone()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    file_handler = logging.FileHandler(os.path.join(args.output_dir, "train_log.txt"))
    logger.logger.addHandler(file_handler)
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Get the datasets
    data_files = {}
    data_files_eval = {}
    if args.train_file is not None:
        data_files["train"] = args.train_file
    if args.validation_file is not None:
        data_files_eval["validation"] = args.validation_file
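    # load_dataset infers the dataset builder from the file extension
    # (e.g. "json" or "csv"), so train and validation files must share a format.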
    extension = args.train_file.split(".")[-1]
    raw_datasets = load_dataset(extension, data_files=data_files)
    raw_datasets_eval = load_dataset(extension, data_files=data_files_eval)

    # Load pretrained model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
    if args.config_name_or_path is not None:
        config = BartConfig.from_pretrained(args.config_name_or_path)
    else:
        config = None
    if args.model_name_or_path is not None:
        if config is None:
            model = EnClapBartForConditionalGeneration.from_pretrained(
                args.model_name_or_path
            )
        else:
            model = EnClapBartForConditionalGeneration.from_pretrained(
                args.model_name_or_path, config=config
            )
    else:
        model = EnClapBartForConditionalGeneration(config=config)

    # Set the generation config
    if args.val_max_target_length is None:
        args.val_max_target_length = args.max_target_length

    # Set max encodec length based on the shape of the positional encoding
    max_encodec_length = model.config.max_position_embeddings - 2
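    # (the two positions appear to be reserved for the BOS/EOS special tokens)
    # A label id of -100 is ignored by PyTorch's cross-entropy loss.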
    label_pad_token_id = (
        -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )
    preprocessor = Preprocessor(
        args.encodec_base_path,
        args.clap_base_path,
        tokenizer,
        model.config.max_position_embeddings,
        args.encodec_masking_prob,
        args.encodec_masking_span,
        label_pad_token_id,
        model.config.encodec_vocab_size,
        args.eval_num_captions,
    )

    with accelerator.main_process_first():
        train_dataset = raw_datasets["train"].map(
            preprocessor.preprocess_train,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
        train_dataset.set_format(
            "pt",
            columns=[
                "input_ids",
                "attention_mask",
                "clap",
                "labels",
                "decoder_attention_mask",
            ],
        )
        # Preprocess the validation split.
        eval_dataset = raw_datasets_eval["validation"].map(
            preprocessor.preprocess_eval,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
        eval_dataset.set_format(
            "pt",
            columns=["input_ids", "attention_mask", "clap"],
            output_all_columns=True,
        )

    train_data_collator = DataCollatorForEnClapBart(
        tokenizer=tokenizer,
        model=model,
        return_tensors="pt",
        label_pad_token_id=label_pad_token_id,
        max_length=max_encodec_length,
        encodec_masking_prob=args.encodec_masking_prob,
        encodec_masking_span=args.encodec_masking_span,
    )
    valid_data_collator = DataCollatorForEnClapBart(
        tokenizer=tokenizer,
        model=model,
        return_tensors="pt",
        label_pad_token_id=label_pad_token_id,
        max_length=max_encodec_length,
    )
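    # Note that the validation collator omits the encodec masking arguments, so
    # evaluation always runs on unmasked inputs.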
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=train_data_collator,
        batch_size=args.per_device_train_batch_size,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        collate_fn=valid_data_collator,
        batch_size=args.per_device_eval_batch_size,
    )

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    if args.lr_scheduler_type == "inverse_sqrt" and hasattr(args, "time_scale"):
        lr_scheduler = get_inverse_sqrt_schedule(
            optimizer=optimizer,
            num_warmup_steps=args.num_warmup_steps,
            timescale=args.time_scale,
        )
    else:
        lr_scheduler = get_scheduler(
            name=args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=args.num_warmup_steps,
            num_training_steps=args.max_train_steps,
        )

    # Prepare everything with our `accelerator`.
    (
        model,
        optimizer,
        train_dataloader,
        eval_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)
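    # checkpointing_steps is either an integer string (save every N steps) or
    # the literal "epoch" (save at the end of every epoch, handled below).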

    # The trackers initialize automatically on the main process.
    if args.with_tracking:
        accelerator.init_trackers(args.logging_dir)

    # Train!
    total_batch_size = (
        args.per_device_train_batch_size
        * accelerator.num_processes
        * args.gradient_accumulation_steps
    )
    if args.split_batches:
        total_batch_size = int(total_batch_size / accelerator.num_processes)
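    # With split_batches=True, Accelerate splits each dataloader batch across
    # processes instead of giving each process a full batch, so the effective
    # total batch size is divided back by the number of processes.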
logger.info("***** Running training *****") | |
logger.info(f" Num examples = {len(train_dataset)}") | |
logger.info(f" Num Epochs = {args.num_train_epochs}") | |
logger.info( | |
f" Instantaneous batch size per device = {args.per_device_train_batch_size}" | |
) | |
logger.info( | |
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" | |
) | |
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") | |
logger.info(f" Total optimization steps = {args.max_train_steps}") | |
    completed_steps = 0
    starting_epoch = 0
    # Initialize resume_step so the resume check in the epoch loop never hits an
    # undefined name when no checkpoint is loaded.
    resume_step = None

    # Potentially load in the weights and states from a previous save
    if not args.overwrite_output_dir and os.path.exists(
        os.path.join(args.output_dir, "checkpoints")
    ):
        if args.resume_from_checkpoint is not None:
            accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
            accelerator.load_state(args.resume_from_checkpoint)
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = [
                f
                for f in os.scandir(os.path.join(args.output_dir, "checkpoints"))
                if f.is_dir()
            ]
            # Sort folders by ctime; the most recent checkpoint is the last
            dirs.sort(key=lambda d: os.path.getctime(d.path))
            path = dirs[-1].name
            accelerator.print(f"Resumed from checkpoint: {dirs[-1].path}")
            accelerator.load_state(dirs[-1].path)
        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(path)[0]
        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = starting_epoch * num_update_steps_per_epoch
        else:
            # need to multiply `gradient_accumulation_steps` to reflect real steps
            resume_step = (
                int(training_difference.replace("step_", ""))
                * args.gradient_accumulation_steps
            )
            starting_epoch = resume_step // len(train_dataloader)
            resume_step -= starting_epoch * len(train_dataloader)
            completed_steps = resume_step // args.gradient_accumulation_steps

    # Loss accumulators (kept regardless of tracking so the tqdm postfix below
    # always has a value to display).
    total_loss = 0
    logging_loss = 0
    before_epoch_loss = 0
    if args.encodec_masking_prob > 0:
        total_encodec_loss = 0
        logging_encodec_loss = 0
        before_epoch_encodec_loss = 0

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        if (
            args.resume_from_checkpoint
            and epoch == starting_epoch
            and resume_step is not None
        ):
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(
                train_dataloader, resume_step
            )
        else:
            active_dataloader = train_dataloader
        logger.info(f"***** Running epoch {epoch} *****")
        epoch_iterator = tqdm(
            active_dataloader,
            desc="Training",
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
            colour="CYAN",
        )
        for step, batch in enumerate(epoch_iterator):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                # We keep track of the loss at each epoch
                total_loss += outputs.lm_loss.item()
                if args.encodec_masking_prob > 0 and outputs.encodec_loss is not None:
                    total_encodec_loss += outputs.encodec_loss.item()
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        model.parameters(), max_norm=args.max_grad_norm
                    )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                completed_steps += 1
                # Add loss information to tqdm
                epoch_iterator.set_postfix(loss=total_loss / completed_steps)
                if completed_steps % args.logging_steps == 0:
                    train_log = {
                        "train/learning_rate": lr_scheduler.get_last_lr()[0]
                    }
                    train_log["train/loss"] = (
                        total_loss - logging_loss
                    ) / args.logging_steps
                    logging_loss = total_loss
                    if args.encodec_masking_prob > 0:
                        train_log["train/encodec_loss"] = (
                            total_encodec_loss - logging_encodec_loss
                        ) / args.logging_steps
                        logging_encodec_loss = total_encodec_loss
                    accelerator.log(train_log, step=completed_steps)

                if isinstance(checkpointing_steps, int):
                    if completed_steps % checkpointing_steps == 0:
                        output_dir = f"step_{completed_steps}"
                        if args.output_dir is not None:
                            output_dir = os.path.join(
                                args.output_dir, "checkpoints", output_dir
                            )
                        accelerator.save_state(output_dir)

                if completed_steps >= args.max_train_steps:
                    break

        model.eval()
        gen_kwargs = {
            "max_length": args.val_max_target_length,
        }
        predictions = []
        references = []
        eval_iterator = tqdm(
            eval_dataloader,
            desc="Validation",
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
            colour="MAGENTA",
        )
        for batch in eval_iterator:
            with torch.no_grad():
                batch["input_ids"] = batch["input_ids"].cuda()
                batch["clap"] = batch["clap"].cuda()
                batch["attention_mask"] = batch["attention_mask"].cuda()
                batch["eos_mask"] = batch["eos_mask"].cuda()
                generated_tokens = accelerator.unwrap_model(model).generate(
                    batch["input_ids"],
                    clap=batch["clap"],
                    attention_mask=batch["attention_mask"],
                    eos_mask=batch["eos_mask"],
                    **gen_kwargs,
                )
                generated_tokens = accelerator.pad_across_processes(
                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                )
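                # pad_across_processes right-pads each process's generations to a
                # common length so they can later be gathered without shape
                # mismatches.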
                generated_tokens = generated_tokens.cpu().numpy()
                captions = batch["captions"]
                if isinstance(generated_tokens, tuple):
                    generated_tokens = generated_tokens[0]
                decoded_preds = tokenizer.batch_decode(
                    generated_tokens, skip_special_tokens=True
                )
                predictions.extend(decoded_preds)
                references.extend(captions)
logger.info("Evaluating predictions...") | |
result = evaluate(predictions, references, metrics=metric_list) | |
# Gather Result | |
result = {k: v.cuda() for k, v in result[0].items()} | |
result = accelerator.gather_for_metrics(result) | |
# Log the average of metrics among the processes | |
if accelerator.num_processes > 1: | |
result = {f"eval/{k}": round(v.mean().item(), 4) for k, v in result.items()} | |
else: | |
result = {f"eval/{k}": round(v.item(), 4) for k, v in result.items()} | |
logger.info(result) | |
        if args.with_tracking:
            result["train/epoch_train_loss"] = (total_loss - before_epoch_loss) / len(
                train_dataloader
            )
            result["train/steps"] = completed_steps
            before_epoch_loss = total_loss
            if args.encodec_masking_prob > 0:
                result["train/epoch_encodec_loss"] = (
                    total_encodec_loss - before_epoch_encodec_loss
                ) / len(train_dataloader)
                before_epoch_encodec_loss = total_encodec_loss
            accelerator.log(result, step=epoch)
        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, "checkpoints", output_dir)
            accelerator.save_state(output_dir)
            if accelerator.is_main_process:
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.config.save_pretrained(output_dir)
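
    # Final export: save the unwrapped model (and, on the main process, the
    # tokenizer) so the "final" directory can be reloaded with from_pretrained().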
    if args.output_dir is not None:
        save_dir = os.path.join(args.output_dir, "final")
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            save_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(save_dir)
if __name__ == "__main__": | |
main() | |