# GitHub Action
# refs/heads/ci-cd/hugging-face
# 8b414b0
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer
from src.utils import get_target_columns
def collate_fn(data):
    """Collate a list of per-sample dicts into one batched dict.

    Each sample tensor is squeezed (dropping size-1 dims) and the samples
    are stacked along a new leading batch dimension, key by key.

    Args:
        data: list of dicts, each with tensor values under the keys
            ``input_ids``, ``token_type_ids``, ``attention_mask``, ``labels``.

    Returns:
        dict mapping the same four keys to batched tensors.
    """
    batch_keys = ('input_ids', 'token_type_ids', 'attention_mask', 'labels')
    return {
        key: torch.stack([sample[key].squeeze() for sample in data])
        for key in batch_keys
    }
class ClassificationDataset(Dataset):
    """Dataset of pre-tokenized texts with multi-target regression labels.

    All texts are tokenized once up front; ``__getitem__`` just indexes the
    cached tensors. When the target columns are present in ``df`` the labels
    come from them; otherwise (inference data) labels are -1 sentinels.
    """

    def __init__(self, tokenizer: BertTokenizer, df: pd.DataFrame, config: dict):
        """Tokenize ``df.full_text`` and attach a ``labels`` tensor.

        Args:
            tokenizer: HuggingFace tokenizer used to encode the texts.
            df: Frame with a ``full_text`` column and, for training data,
                the columns returned by :func:`get_target_columns`.
            config: Must contain ``max_length`` (truncation length).
        """
        self.config = config
        self.tokenizer = tokenizer
        self.df = df
        # Tokenize the whole corpus in one call; padding=True pads every
        # sequence to the longest one, so per-item tensors share a length.
        self.features = self.tokenizer(
            text=df.full_text.tolist(),
            max_length=self.config['max_length'],
            padding=True,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        target_columns = get_target_columns()
        # 'cohesion' is used as a proxy for "all target columns present"
        # (labeled training data vs. unlabeled inference data).
        if 'cohesion' in self.df.columns:
            self.features['labels'] = torch.as_tensor(
                df[target_columns].values, dtype=torch.float32
            )
        else:
            # Unlabeled frames get -1 sentinels. Size the tensor from the
            # target-column list instead of a hard-coded 6 so both branches
            # stay consistent if the target set ever changes.
            self.features['labels'] = torch.full(
                (len(df), len(target_columns)), -1.0, dtype=torch.float32
            )

    def __getitem__(self, item):
        """Returns dict with input_ids, token_type_ids, attention_mask, labels
        """
        return {
            'input_ids': self.features['input_ids'][item],
            'token_type_ids': self.features['token_type_ids'][item],
            'attention_mask': self.features['attention_mask'][item],
            'labels': self.features['labels'][item]
        }

    def __len__(self):
        """Number of rows in the backing data frame."""
        return len(self.df)
class ClassificationDataloader(pl.LightningDataModule):
    """Lightning data module serving train/validation ClassificationDatasets."""

    def __init__(
        self,
        tokenizer: BertTokenizer,
        train_df: pd.DataFrame,
        val_df: pd.DataFrame,
        config: dict
    ):
        """Build the train and validation datasets from their data frames.

        Args:
            tokenizer: HuggingFace tokenizer passed through to the datasets.
            train_df: Training split.
            val_df: Validation split.
            config: Must contain ``batch_size`` and ``num_workers`` (plus
                whatever :class:`ClassificationDataset` reads).
        """
        super().__init__()
        self.config = config
        self.train_data = ClassificationDataset(tokenizer, train_df, config)
        self.val_data = ClassificationDataset(tokenizer, val_df, config)

    def _build_loader(self, dataset, shuffle):
        # Shared DataLoader construction: the two public loaders differ
        # only in which dataset they wrap and whether they shuffle.
        return DataLoader(
            dataset=dataset,
            shuffle=shuffle,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_workers'],
            collate_fn=collate_fn
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._build_loader(self.train_data, True)

    def val_dataloader(self):
        """Deterministic (unshuffled) loader over the validation split."""
        return self._build_loader(self.val_data, False)