# Milestone3/finetuning.py: fine-tune DistilBERT for multi-label toxic comment classification
import shutil

import numpy as np
import pandas as pd
import torch
from torch import cuda
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertTokenizer

device = 'cuda' if cuda.is_available() else 'cpu'

# The six Jigsaw toxicity labels predicted jointly (multi-label classification).
label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
# Hyperparameters
MAX_LEN = 512
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32
EPOCHS = 2
LEARNING_RATE = 1e-05

# Load the labelled comments and work on a small random subset to keep runs fast.
df_train = pd.read_csv("train.csv")
df_train = df_train.sample(n=512)

# 80/20 train/validation split of the sampled data.
train_size = 0.8
df_train_sampled = df_train.sample(frac=train_size, random_state=44)
df_val = df_train.drop(df_train_sampled.index).reset_index(drop=True)
df_train_sampled = df_train_sampled.reset_index(drop=True)

model_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(model_name, do_lower_case=True)
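# Optional sanity check (an illustrative sketch, not part of the original pipeline):
# encode one comment and confirm the padded/truncated tensors match MAX_LEN. The row
# used here is simply whichever comment the sample happened to draw first.
example_enc = tokenizer.encode_plus(
    str(df_train_sampled.comment_text.iloc[0]),
    add_special_tokens=True,
    max_length=MAX_LEN,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt'
)
print(example_enc['input_ids'].shape, example_enc['attention_mask'].shape)  # both (1, MAX_LEN)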
class ToxicDataset(Dataset):
    """Wraps a dataframe of comments and returns tokenized inputs plus the six-label target vector."""
    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.labels = self.data[label_cols].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Comment text for this row only.
        text = str(self.data.comment_text.iloc[idx])
        tokenized_text = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'input_ids': tokenized_text['input_ids'].flatten(),
            'attention_mask': tokenized_text['attention_mask'].flatten(),
            'targets': torch.FloatTensor(self.labels[idx])
        }
train_dataset = ToxicDataset(df_train_sampled, tokenizer, MAX_LEN)
valid_dataset = ToxicDataset(df_val, tokenizer, MAX_LEN)

train_data_loader = DataLoader(train_dataset,
                               batch_size=TRAIN_BATCH_SIZE,
                               shuffle=True,
                               num_workers=0)
val_data_loader = DataLoader(valid_dataset,
                             batch_size=VALID_BATCH_SIZE,
                             shuffle=False,
                             num_workers=0)
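# Optional sanity check (illustrative): pull a single item and a single batch
# to confirm the shapes the model will receive during training.
sample_item = train_dataset[0]
print(sample_item['input_ids'].shape, sample_item['targets'].shape)  # (MAX_LEN,), (6,)
sample_batch = next(iter(train_data_loader))
print(sample_batch['input_ids'].shape)  # (TRAIN_BATCH_SIZE, MAX_LEN)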
class CustomDistilBertClass(torch.nn.Module):
    """DistilBERT encoder with a dropout + linear head producing 6 logits per token position."""
    def __init__(self):
        super(CustomDistilBertClass, self).__init__()
        self.distilbert_model = DistilBertModel.from_pretrained(model_name, return_dict=True)
        self.dropout = torch.nn.Dropout(0.3)
        self.linear = torch.nn.Linear(768, 6)

    def forward(self, input_ids, attn_mask):
        output = self.distilbert_model(
            input_ids,
            attention_mask=attn_mask,
        )
        # last_hidden_state has shape (batch, seq_len, 768); the head is applied to
        # every position, and the training loop selects the [CLS] position.
        output_dropout = self.dropout(output.last_hidden_state)
        output = self.linear(output_dropout)
        return output


model = CustomDistilBertClass()
model.to(device)


def loss_fn(outputs, targets):
    # Multi-label objective: one sigmoid + binary cross-entropy per label.
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)


optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
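# Optional sanity check (illustrative): the model returns one 6-dim logit vector per
# token position, i.e. shape (batch, seq_len, 6); the training loop below keeps only
# the [CLS] position (index 0) as the sentence-level prediction for BCEWithLogitsLoss.
with torch.no_grad():
    _batch = next(iter(val_data_loader))
    _logits = model(_batch['input_ids'].to(device), _batch['attention_mask'].to(device))
    print(_logits.shape)           # (VALID_BATCH_SIZE, MAX_LEN, 6)
    print(_logits[:, 0, :].shape)  # (VALID_BATCH_SIZE, 6)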
def train_model(n_epochs, training_loader, validation_loader, model,
                optimizer, checkpoint_path, best_model_path):
    valid_loss_min = np.inf
    for epoch in range(1, n_epochs + 1):
        train_loss = 0
        valid_loss = 0

        model.train()
        print('############# Epoch {}: Training Start #############'.format(epoch))
        for batch_idx, data in enumerate(training_loader):
            ids = data['input_ids'].to(device, dtype=torch.long)
            mask = data['attention_mask'].to(device, dtype=torch.long)
            targets = data['targets'].to(device, dtype=torch.float)

            outputs = model(ids, mask)
            # Keep only the [CLS] position as the sentence-level logits: (batch, 6).
            outputs = outputs[:, 0, :]
            loss = loss_fn(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Incremental update keeps train_loss equal to the mean loss over the batches seen so far.
            train_loss = train_loss + ((1 / (batch_idx + 1)) * (loss.item() - train_loss))
        print('############# Epoch {}: Training End #############'.format(epoch))

        print('############# Epoch {}: Validation Start #############'.format(epoch))
        model.eval()
        with torch.no_grad():
            for batch_idx, data in enumerate(validation_loader, 0):
                ids = data['input_ids'].to(device, dtype=torch.long)
                mask = data['attention_mask'].to(device, dtype=torch.long)
                targets = data['targets'].to(device, dtype=torch.float)
                outputs = model(ids, mask)
                outputs = outputs[:, 0, :]
                loss = loss_fn(outputs, targets)
                valid_loss = valid_loss + ((1 / (batch_idx + 1)) * (loss.item() - valid_loss))
        print('############# Epoch {}: Validation End #############'.format(epoch))

        print('Epoch: {} \tAverage Training Loss: {:.6f} \tAverage Validation Loss: {:.6f}'.format(
            epoch,
            train_loss,
            valid_loss
        ))

        checkpoint = {
            'epoch': epoch + 1,
            'valid_loss_min': valid_loss,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        # Always save the latest checkpoint; additionally copy it to best_model_path
        # whenever the validation loss improves.
        save_ckp(checkpoint, False, checkpoint_path, best_model_path)
        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(valid_loss_min, valid_loss))
            save_ckp(checkpoint, True, checkpoint_path, best_model_path)
            valid_loss_min = valid_loss
        print('############# Epoch {} Done #############\n'.format(epoch))
    return model
def load_ckp(checkpoint_fpath, model, optimizer):
    """
    checkpoint_fpath: path of the saved checkpoint to load
    model: model that we want to load checkpoint parameters into
    optimizer: optimizer we defined in previous training
    """
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # valid_loss_min was stored as a plain float, so it is returned as-is.
    valid_loss_min = checkpoint['valid_loss_min']
    return model, optimizer, checkpoint['epoch'], valid_loss_min


def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """
    state: checkpoint we want to save
    is_best: is this the best checkpoint (minimum validation loss so far)
    checkpoint_path: path to save the checkpoint
    best_model_path: path to save the best model
    """
    f_path = checkpoint_path
    torch.save(state, f_path)
    if is_best:
        best_fpath = best_model_path
        shutil.copyfile(f_path, best_fpath)
ckpt_path = "model.pt"
best_model_path = "best_model.pt"

trained_model = train_model(EPOCHS,
                            train_data_loader,
                            val_data_loader,
                            model,
                            optimizer,
                            ckpt_path,
                            best_model_path)
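# Illustrative follow-up (a minimal sketch, not part of the original run): reload the
# best checkpoint saved above and score a new comment. The example comment is a
# hypothetical input used only for demonstration.
model, optimizer, start_epoch, best_val_loss = load_ckp(best_model_path, model, optimizer)
model.eval()

example_comment = "You are a wonderful person."  # hypothetical input
encoded = tokenizer.encode_plus(
    example_comment,
    add_special_tokens=True,
    max_length=MAX_LEN,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt'
)
with torch.no_grad():
    logits = model(encoded['input_ids'].to(device), encoded['attention_mask'].to(device))
    probs = torch.sigmoid(logits[:, 0, :]).cpu().numpy()[0]
# Map each of the six label names to its predicted probability.
predictions = {label: float(p) for label, p in zip(label_cols, probs)}
print(predictions)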