import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer

from src.utils import get_target_columns


def collate_fn(data):
    """Stack per-sample feature tensors into a single batch dict."""
    input_ids = []
    token_type_ids = []
    attention_mask = []
    labels = []
    for item in data:
        input_ids.append(item['input_ids'].squeeze())
        token_type_ids.append(item['token_type_ids'].squeeze())
        attention_mask.append(item['attention_mask'].squeeze())
        labels.append(item['labels'].squeeze())
    return {
        'input_ids': torch.stack(input_ids),
        'token_type_ids': torch.stack(token_type_ids),
        'attention_mask': torch.stack(attention_mask),
        'labels': torch.stack(labels),
    }


class ClassificationDataset(Dataset):
    """Tokenizes every text up front and serves the results as tensor slices."""

    def __init__(self, tokenizer: BertTokenizer, df: pd.DataFrame, config: dict):
        self.config = config
        self.tokenizer = tokenizer
        self.df = df
        # Tokenize the whole dataframe in one call; padding=True pads each
        # sequence to the longest one in the frame, capped at max_length.
        self.features = self.tokenizer(
            text=df.full_text.tolist(),
            max_length=self.config['max_length'],
            padding=True,
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt',
        )
        if 'cohesion' in self.df.columns:
            # Labelled (train/val) data: pull the target scores as floats.
            self.features['labels'] = torch.as_tensor(
                df[get_target_columns()].values, dtype=torch.float32
            )
        else:
            # Unlabelled (test) data: fill the labels with -1 placeholders so
            # the batch layout stays identical to the labelled case.
            self.features['labels'] = torch.full(
                size=(len(df), len(get_target_columns())),
                fill_value=-1.0,
                dtype=torch.float32,
            )

    def __getitem__(self, item):
        """Return a dict with input_ids, token_type_ids, attention_mask, labels."""
        return {
            'input_ids': self.features['input_ids'][item],
            'token_type_ids': self.features['token_type_ids'][item],
            'attention_mask': self.features['attention_mask'][item],
            'labels': self.features['labels'][item],
        }

    def __len__(self):
        return len(self.df)


class ClassificationDataloader(pl.LightningDataModule):
    """LightningDataModule wrapping the train and validation datasets."""

    def __init__(
        self,
        tokenizer: BertTokenizer,
        train_df: pd.DataFrame,
        val_df: pd.DataFrame,
        config: dict,
    ):
        super().__init__()
        self.config = config
        self.train_data = ClassificationDataset(tokenizer, train_df, config)
        self.val_data = ClassificationDataset(tokenizer, val_df, config)

    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_data,
            shuffle=True,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_workers'],
            collate_fn=collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.val_data,
            shuffle=False,
            batch_size=self.config['batch_size'],
            num_workers=self.config['num_workers'],
            collate_fn=collate_fn,
        )
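
# Minimal usage sketch, not part of the original module: the checkpoint name,
# config values, and toy dataframes below are illustrative assumptions; the
# real config and CSV loading live in the training script. It only assumes
# the frames carry a `full_text` column plus the columns returned by
# get_target_columns() for labelled data.
if __name__ == '__main__':
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    config = {'max_length': 512, 'batch_size': 2, 'num_workers': 0}
    # Hypothetical labelled frame: two essays with dummy target scores.
    train_df = pd.DataFrame({
        'full_text': ['An example essay.', 'Another example essay.'],
        **{col: [3.0, 3.5] for col in get_target_columns()},
    })
    val_df = train_df.copy()
    dm = ClassificationDataloader(tokenizer, train_df, val_df, config)
    # One collated batch; each value is a tensor of shape (batch, ...).
    batch = next(iter(dm.train_dataloader()))
    print({k: tuple(v.shape) for k, v in batch.items()})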