import os
import shutil

import numpy as np
import pandas as pd
import torch
import transformers
from torch import cuda
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from transformers import DistilBertModel, DistilBertTokenizer

device = 'cuda' if cuda.is_available() else 'cpu'

label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
df_train = pd.read_csv("train.csv")

# Hyperparameters
MAX_LEN = 512
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32
EPOCHS = 2
LEARNING_RATE = 1e-05

# Sub-sample 512 rows so the example runs quickly, then make an 80/20 train/validation split
df_train = df_train.sample(n=512)
train_size = 0.8
df_train_sampled = df_train.sample(frac=train_size, random_state=44)
df_val = df_train.drop(df_train_sampled.index).reset_index(drop=True)
df_train_sampled = df_train_sampled.reset_index(drop=True)

model_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(model_name, do_lower_case=True)
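# Optional sanity check (not part of the original pipeline): inspect what
# encode_plus returns for a single comment so the dataset shapes below make sense.
sample_encoding = tokenizer.encode_plus(
    str(df_train_sampled.comment_text.iloc[0]),
    add_special_tokens=True,
    max_length=MAX_LEN,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt'
)
print(sample_encoding['input_ids'].shape)       # torch.Size([1, 512])
print(sample_encoding['attention_mask'].shape)  # torch.Size([1, 512])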
class ToxicDataset(Dataset):
    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.labels = self.data[label_cols].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Tokenize the comment at this index
        text = str(self.data.comment_text[idx])
        tokenized_text = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'input_ids': tokenized_text['input_ids'].flatten(),
            'attention_mask': tokenized_text['attention_mask'].flatten(),
            'targets': torch.FloatTensor(self.labels[idx])
        }
train_dataset = ToxicDataset(df_train_sampled, tokenizer, MAX_LEN)
valid_dataset = ToxicDataset(df_val, tokenizer, MAX_LEN)

train_data_loader = DataLoader(train_dataset,
                               batch_size=TRAIN_BATCH_SIZE,
                               shuffle=True,
                               num_workers=0)
val_data_loader = DataLoader(valid_dataset,
                             batch_size=VALID_BATCH_SIZE,
                             shuffle=False,
                             num_workers=0)
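# Optional sanity check (not part of the original pipeline): pull one batch and
# confirm the tensor shapes the model will receive.
batch = next(iter(train_data_loader))
print(batch['input_ids'].shape)       # torch.Size([32, 512])
print(batch['attention_mask'].shape)  # torch.Size([32, 512])
print(batch['targets'].shape)         # torch.Size([32, 6])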
class CustomDistilBertClass(torch.nn.Module):
    def __init__(self):
        super(CustomDistilBertClass, self).__init__()
        self.distilbert_model = DistilBertModel.from_pretrained(model_name, return_dict=True)
        self.dropout = torch.nn.Dropout(0.3)
        self.linear = torch.nn.Linear(768, 6)

    def forward(self, input_ids, attn_mask):
        output = self.distilbert_model(
            input_ids,
            attention_mask=attn_mask,
        )
        # Use the hidden state of the first ([CLS]) token as the sequence representation,
        # then project it to one logit per label
        cls_hidden_state = output.last_hidden_state[:, 0, :]
        output_dropout = self.dropout(cls_hidden_state)
        logits = self.linear(output_dropout)
        return logits

model = CustomDistilBertClass()
model.to(device)
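# Optional sanity check (not part of the original pipeline): run the batch pulled
# above through the model and confirm it produces one logit per label.
with torch.no_grad():
    sample_logits = model(batch['input_ids'].to(device),
                          batch['attention_mask'].to(device))
print(sample_logits.shape)  # torch.Size([32, 6])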
def loss_fn(outputs, targets):
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)

optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
def train_model(n_epochs, training_loader, validation_loader, model,
                optimizer, checkpoint_path, best_model_path):
    valid_loss_min = np.inf
    for epoch in range(1, n_epochs + 1):
        train_loss = 0
        valid_loss = 0

        model.train()
        print('############# Epoch {}: Training Start #############'.format(epoch))
        for batch_idx, data in enumerate(training_loader):
            ids = data['input_ids'].to(device, dtype=torch.long)
            mask = data['attention_mask'].to(device, dtype=torch.long)
            targets = data['targets'].to(device, dtype=torch.float)
            outputs = model(ids, mask)
            loss = loss_fn(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Running mean of the training loss
            train_loss = train_loss + ((1 / (batch_idx + 1)) * (loss.item() - train_loss))
        print('############# Epoch {}: Training End #############'.format(epoch))

        print('############# Epoch {}: Validation Start #############'.format(epoch))
        model.eval()
        with torch.no_grad():
            for batch_idx, data in enumerate(validation_loader, 0):
                ids = data['input_ids'].to(device, dtype=torch.long)
                mask = data['attention_mask'].to(device, dtype=torch.long)
                targets = data['targets'].to(device, dtype=torch.float)
                outputs = model(ids, mask)
                loss = loss_fn(outputs, targets)
                # Running mean of the validation loss
                valid_loss = valid_loss + ((1 / (batch_idx + 1)) * (loss.item() - valid_loss))
        print('############# Epoch {}: Validation End #############'.format(epoch))

        # train_loss and valid_loss are already running means, so no further averaging is needed
        print('Epoch: {} \tAverage Training Loss: {:.6f} \tAverage Validation Loss: {:.6f}'.format(
            epoch,
            train_loss,
            valid_loss
        ))
        checkpoint = {
            'epoch': epoch + 1,
            'valid_loss_min': valid_loss,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        # Always save the latest checkpoint; also copy it to best_model_path when validation loss improves
        save_ckp(checkpoint, False, checkpoint_path, best_model_path)
        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...'.format(valid_loss_min, valid_loss))
            save_ckp(checkpoint, True, checkpoint_path, best_model_path)
            valid_loss_min = valid_loss

        print('############# Epoch {} Done #############\n'.format(epoch))

    return model
def load_ckp(checkpoint_fpath, model, optimizer):
    """
    checkpoint_fpath: path of the checkpoint to load
    model: model that we want to load checkpoint parameters into
    optimizer: optimizer we defined in previous training
    """
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    valid_loss_min = checkpoint['valid_loss_min']
    return model, optimizer, checkpoint['epoch'], valid_loss_min
def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """
    state: checkpoint we want to save
    is_best: is this the best checkpoint (minimum validation loss)
    checkpoint_path: path to save checkpoint
    best_model_path: path to save best model
    """
    f_path = checkpoint_path
    torch.save(state, f_path)
    if is_best:
        best_fpath = best_model_path
        shutil.copyfile(f_path, best_fpath)
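# Optional, not part of the original pipeline: a minimal sketch of resuming from a
# saved checkpoint with load_ckp, assuming "model.pt" already exists on disk.
# model, optimizer, start_epoch, valid_loss_min = load_ckp("model.pt", model, optimizer)
# print("Resuming from epoch {}, best validation loss so far: {:.6f}".format(start_epoch, valid_loss_min))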
ckpt_path = "model.pt"
best_model_path = "best_model.pt"

trained_model = train_model(EPOCHS,
                            train_data_loader,
                            val_data_loader,
                            model,
                            optimizer,
                            ckpt_path,
                            best_model_path)
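# Optional, not part of the original pipeline: a minimal inference sketch with the
# trained model. BCEWithLogitsLoss was applied to raw logits during training, so a
# sigmoid is needed here to turn the logits into per-label probabilities.
example_comment = "This is a perfectly polite comment."
encoding = tokenizer.encode_plus(
    example_comment,
    add_special_tokens=True,
    max_length=MAX_LEN,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt'
)
trained_model.eval()
with torch.no_grad():
    logits = trained_model(encoding['input_ids'].to(device),
                           encoding['attention_mask'].to(device))
    probabilities = torch.sigmoid(logits).cpu().numpy()[0]
for label, prob in zip(label_cols, probabilities):
    print("{}: {:.4f}".format(label, prob))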