#!/usr/bin/env python
# coding: utf-8
# # Creating a Zero-Shot classifier based on BETO
#
# This notebook/script fine-tunes a BETO (Spanish BERT, 'dccuchile/bert-base-spanish-wwm-cased') model on the Spanish portion of the XNLI dataset.
# The fine-tuned model can then be fed to a Hugging Face zero-shot classification pipeline to obtain a zero-shot classifier.
# In[ ]:
from datasets import load_dataset, load_metric, load_from_disk
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import Trainer, TrainingArguments
import torch
from pathlib import Path
# from ray import tune
# from ray.tune.suggest.hyperopt import HyperOptSearch
# from ray.tune.schedulers import ASHAScheduler
# # Prepare the datasets
# In[ ]:
xnli_es = load_dataset("xnli", "es")
# In[ ]:
xnli_es
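# Before remapping anything, it's worth confirming the label schema. A quick
# check (the `features` attribute is standard `datasets` API):
# In[ ]:
# XNLI encodes 0 = entailment, 1 = neutral, 2 = contradiction
xnli_es["train"].features["label"].names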
# >joeddav
# >Aug '20
# >
# >@rsk97 In addition, just make sure the model used is trained on an NLI task and that the **last output label corresponds to entailment** while the **first output label corresponds to contradiction**.
#
# => We therefore swap label ids 0 and 2 and store the result in a `labels` column, which is the column name an `AutoModelForSequenceClassification` expects.
# In[ ]:
# see markdown above; XNLI uses 0 = entailment, 1 = neutral, 2 = contradiction,
# so we swap 0 and 2 to get contradiction first and entailment last
def switch_label_id(row):
    if row["label"] == 0:
        return {"labels": 2}
    elif row["label"] == 2:
        return {"labels": 0}
    else:
        return {"labels": 1}


for split in xnli_es:
    xnli_es[split] = xnli_es[split].map(switch_label_id)
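# A quick sanity check on a few rows (illustrative): 0 and 2 swapped, 1 unchanged.
# In[ ]:
for row in xnli_es["train"].select(range(5)):
    assert row["labels"] == {0: 2, 2: 0}.get(row["label"], 1)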
# ## Tokenize data
# In[ ]:
tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")
# In a first attempt I padded all rows to the maximum length found in the dataset (379). However, training takes substantially longer with all that padding; it's better to pass the tokenizer to the `Trainer` and let it pad on a per-batch level.
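# For illustration, this is roughly what the `Trainer` sets up internally when it
# is given a tokenizer: a `DataCollatorWithPadding` that pads each batch to its
# longest sequence instead of to a global maximum (not used explicitly below).
# In[ ]:
# from transformers import DataCollatorWithPadding
# data_collator = DataCollatorWithPadding(tokenizer)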
# In[ ]:
# Maximum length of the dataset, figured out manually
# max_length = 379
def tokenize(row):
    return tokenizer(row["premise"], row["hypothesis"], truncation=True, max_length=512)  # , padding="max_length", max_length=max_length)
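# Quick look at how a single premise/hypothesis pair is encoded; BERT-style
# models pack the pair as "[CLS] premise [SEP] hypothesis [SEP]":
# In[ ]:
tokenizer.decode(tokenize(xnli_es["train"][0])["input_ids"])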
# In[ ]:
data = {}
for split in xnli_es:
    data[split] = xnli_es[split].map(
        tokenize,
        remove_columns=["hypothesis", "premise", "label"],
        batched=True,
        batch_size=128
    )
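# The tokenized splits should now carry `labels` plus the model inputs produced
# by the tokenizer (`input_ids`, `token_type_ids`, `attention_mask`):
# In[ ]:
data["train"].column_names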
# In[ ]:
# Persist the tokenized splits so they can be reloaded by path inside the
# trainable below (handy when trials run in separate worker processes)
train_path = str(Path("./train_ds").absolute())
valid_path = str(Path("./valid_ds").absolute())
data["train"].save_to_disk(train_path)
data["validation"].save_to_disk(valid_path)
# In[ ]:
# We can use `datasets.Dataset` objects directly, so this wrapper is kept only for reference
# class XnliDataset(torch.utils.data.Dataset):
#     def __init__(self, data):
#         self.data = data
#
#     def __getitem__(self, idx):
#         item = {key: torch.tensor(val) for key, val in self.data[idx].items()}
#         return item
#
#     def __len__(self):
#         return len(self.data)
# In[ ]:
def trainable(config):
    # the `xnli` metric is plain accuracy
    metric = load_metric("xnli", "es")

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = predictions.argmax(axis=-1)
        return metric.compute(predictions=predictions, references=labels)

    model = AutoModelForSequenceClassification.from_pretrained("dccuchile/bert-base-spanish-wwm-cased", num_labels=3)

    training_args = TrainingArguments(
        output_dir='./results',                                # output directory
        do_train=True,
        do_eval=True,
        evaluation_strategy="steps",
        eval_steps=500,
        load_best_model_at_end=True,
        metric_for_best_model="eval_accuracy",
        num_train_epochs=config["epochs"],                     # total number of training epochs
        per_device_train_batch_size=config["batch_size"],      # batch size per device during training
        per_device_eval_batch_size=config["batch_size_eval"],  # batch size for evaluation
        warmup_steps=config["warmup_steps"],                   # 500
        weight_decay=config["weight_decay"],                   # 0.001 # strength of weight decay
        learning_rate=config["learning_rate"],                 # 5e-05
        logging_dir='./logs',                                  # directory for storing logs
        logging_steps=250,
        # save_steps=500,                                      # ignored when using load_best_model_at_end
        save_total_limit=10,
        no_cuda=False,
        disable_tqdm=True,
    )

    # train_dataset = XnliDataset(load_from_disk(config["train_path"]))
    # valid_dataset = XnliDataset(load_from_disk(config["valid_path"]))
    train_dataset = load_from_disk(config["train_path"])
    valid_dataset = load_from_disk(config["valid_path"])

    # passing the tokenizer enables dynamic padding on a per-batch basis
    trainer = Trainer(
        model,
        tokenizer=tokenizer,
        args=training_args,           # training arguments, defined above
        train_dataset=train_dataset,  # training dataset
        eval_dataset=valid_dataset,   # evaluation dataset
        compute_metrics=compute_metrics,
    )
    trainer.train()
# In[ ]:
trainable(
    {
        "train_path": train_path,
        "valid_path": valid_path,
        "batch_size": 16,
        "batch_size_eval": 64,
        "warmup_steps": 500,
        "weight_decay": 0.001,
        "learning_rate": 5e-5,
        "epochs": 3,
    }
)
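# With a fine-tuned checkpoint, building the zero-shot classifier is a one-liner.
# A minimal sketch, assuming the best checkpoint ended up under `./results`
# (the exact `checkpoint-XXXX` directory is hypothetical and depends on the run):
# In[ ]:
# from transformers import pipeline
# classifier = pipeline(
#     "zero-shot-classification",
#     model="./results/checkpoint-XXXX",  # hypothetical path, pick your best checkpoint
#     tokenizer="dccuchile/bert-base-spanish-wwm-cased",
# )
# classifier(
#     "El equipo ganó el campeonato nacional",
#     candidate_labels=["deportes", "política", "cultura"],
# )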
# # HPO
# In[ ]:
# config = {
#     "train_path": train_path,
#     "valid_path": valid_path,
#     "warmup_steps": tune.randint(0, 500),
#     "weight_decay": tune.loguniform(0.00001, 0.1),
#     "learning_rate": tune.loguniform(5e-6, 5e-4),
#     "epochs": tune.choice([2, 3, 4])
# }
# In[ ]:
# Note: for `tune.run` the trainable would also need to report metrics to Tune
# (e.g. via `tune.report`), and the metric name must match what is reported;
# the Trainer logs the accuracy from compute_metrics as "eval_accuracy"
# analysis = tune.run(
#     trainable,
#     config=config,
#     metric="eval_accuracy",
#     mode="max",
#     # search_alg=HyperOptSearch(),
#     # scheduler=ASHAScheduler(),
#     num_samples=1,
# )
# In[ ]:
# def model_init():
#     return AutoModelForSequenceClassification.from_pretrained("dccuchile/bert-base-spanish-wwm-cased", num_labels=3)
# trainer = Trainer(
#     args=training_args,           # training arguments, defined above
#     train_dataset=train_dataset,  # training dataset
#     eval_dataset=valid_dataset,   # evaluation dataset
#     model_init=model_init,
#     compute_metrics=compute_metrics,
# )
# In[ ]:
# best_trial = trainer.hyperparameter_search(
#     direction="maximize",
#     backend="ray",
#     n_trials=2,
#     # Choose among many libraries:
#     # https://docs.ray.io/en/latest/tune/api_docs/suggestion.html
#     search_alg=HyperOptSearch(mode="max", metric="accuracy"),
#     # Choose among schedulers:
#     # https://docs.ray.io/en/latest/tune/api_docs/schedulers.html
#     scheduler=ASHAScheduler(mode="max", metric="accuracy"),
#     local_dir="tune_runs",
# )