from pytorch_lightning import LightningDataModule
from typing import Optional

from torch.utils.data import DataLoader, DistributedSampler


def get_consume_samples(data_model: LightningDataModule) -> int:
    if hasattr(data_model.trainer.lightning_module, 'consumed_samples'):
        consumed_samples = data_model.trainer.lightning_module.consumed_samples
        print('get consumed samples from model: {}'.format(consumed_samples))
    else:
        world_size = data_model.trainer.world_size
        consumed_samples = max(0, data_model.trainer.global_step - 1) * \
            data_model.hparams.train_batchsize * world_size * data_model.trainer.accumulate_grad_batches
        print('calculate consumed samples: {}'.format(consumed_samples))
    return consumed_samples
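
# Worked example (hypothetical numbers): with global_step=101, train_batchsize=32,
# world_size=8 and accumulate_grad_batches=1, the fallback branch computes
# consumed_samples = max(0, 101 - 1) * 32 * 8 * 1 = 25600.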


class UniversalDataModule(LightningDataModule):
    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('Universal DataModule')
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--dataloader_workers', default=2, type=int)
        parser.add_argument('--train_batchsize', default=32, type=int)
        parser.add_argument('--val_batchsize', default=32, type=int)
        parser.add_argument('--test_batchsize', default=32, type=int)
        parser.add_argument('--datasets_name', type=str, default=None)
        parser.add_argument('--train_datasets_field', type=str, default='train')
        parser.add_argument('--val_datasets_field', type=str, default='validation')
        parser.add_argument('--test_datasets_field', type=str, default='test')
        parser.add_argument('--train_file', type=str, default=None)
        parser.add_argument('--val_file', type=str, default=None)
        parser.add_argument('--test_file', type=str, default=None)
        parser.add_argument('--raw_file_type', type=str, default='json')
        parser.add_argument('--sampler_type', type=str,
                            choices=['single',
                                     'random'],
                            default='random')
        return parent_args
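
    # Illustrative invocation (hypothetical script and dataset names, shown
    # only to document the flags registered above):
    #   python train.py --datasets_name my_dataset --train_batchsize 32 \
    #       --sampler_type random --dataloader_workers 2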

    def __init__(
        self,
        tokenizer,
        collate_fn,
        args,
        datasets=None,
        **kwargs,
    ):
        super().__init__()
        # If no dataset name is passed in, the internal datasets can be
        # replaced from outside the object with whatever the model needs.
        if datasets is not None:
            self.datasets = datasets
        elif args.datasets_name is not None:
            from fengshen.data.fs_datasets import load_dataset
            print('---------begin to load datasets {}'.format(args.datasets_name))
            self.datasets = load_dataset(
                args.datasets_name, num_proc=args.num_workers)
            print('---------end loading datasets {}'.format(args.datasets_name))
        else:
            print('---------begin to load datasets from local file')
            from datasets import load_dataset
            self.datasets = load_dataset(args.raw_file_type,
                                         data_files={
                                             args.train_datasets_field: args.train_file,
                                             args.val_datasets_field: args.val_file,
                                             args.test_datasets_field: args.test_file})
            print('---------end loading datasets from local file')
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.save_hyperparameters(args)
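
    # Note: save_hyperparameters(args) exposes every argparse field as
    # self.hparams.<name>; the sampler and dataloader methods below rely on
    # this (e.g. self.hparams.train_batchsize, self.hparams.sampler_type).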

    def get_custom_sampler(self, ds):
        from .universal_sampler import PretrainingRandomSampler
        from .universal_sampler import PretrainingSampler
        world_size = self.trainer.world_size
        consumed_samples = get_consume_samples(self)
        # pick the sampler type requested by the user
        if self.hparams.sampler_type == 'random':
            return PretrainingRandomSampler(
                total_samples=len(ds),
                # consumed_samples is calculated from the global step
                consumed_samples=consumed_samples,
                micro_batch_size=self.hparams.train_batchsize,
                data_parallel_rank=self.trainer.global_rank,
                data_parallel_size=world_size,
                epoch=self.trainer.current_epoch,
            )
        elif self.hparams.sampler_type == 'single':
            return PretrainingSampler(
                total_samples=len(ds),
                # consumed_samples is calculated from the global step
                consumed_samples=consumed_samples,
                micro_batch_size=self.hparams.train_batchsize,
                data_parallel_rank=self.trainer.global_rank,
                data_parallel_size=world_size,
            )
        else:
            raise Exception('Unknown sampler type: {}'.format(self.hparams.sampler_type))
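
    # The samplers above are assumed to be batch samplers (each iteration
    # yields a list of indices for one micro batch), which is why
    # train_dataloader passes the result via `batch_sampler` rather than
    # `sampler`.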

    def setup(self, stage: Optional[str] = None) -> None:
        return

    def train_dataloader(self):
        ds = self.datasets[self.hparams.train_datasets_field]

        collate_fn = self.collate_fn
        if collate_fn is None and hasattr(ds, 'collater'):
            collate_fn = ds.collater

        if self.hparams.replace_sampler_ddp is False:
            return DataLoader(
                ds,
                batch_sampler=self.get_custom_sampler(ds),
                num_workers=self.hparams.dataloader_workers,
                collate_fn=collate_fn,
                pin_memory=True,
            )
        return DataLoader(
            ds,
            batch_size=self.hparams.train_batchsize,
            num_workers=self.hparams.dataloader_workers,
            collate_fn=collate_fn,
            pin_memory=True,
        )

    def val_dataloader(self):
        ds = self.datasets[self.hparams.val_datasets_field]

        collate_fn = self.collate_fn
        if collate_fn is None and hasattr(ds, 'collater'):
            collate_fn = ds.collater

        return DataLoader(
            ds,
            batch_size=self.hparams.val_batchsize,
            shuffle=False,
            num_workers=self.hparams.dataloader_workers,
            collate_fn=collate_fn,
            sampler=DistributedSampler(
                ds, shuffle=False),
            pin_memory=True,
        )

    def test_dataloader(self):
        ds = self.datasets[self.hparams.test_datasets_field]

        collate_fn = self.collate_fn
        if collate_fn is None and hasattr(ds, 'collater'):
            collate_fn = ds.collater

        return DataLoader(
            ds,
            batch_size=self.hparams.test_batchsize,
            shuffle=False,
            num_workers=self.hparams.dataloader_workers,
            collate_fn=collate_fn,
            sampler=DistributedSampler(
                ds, shuffle=False),
            pin_memory=True,
        )
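

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the script wiring, the `model`
# variable and the tokenizer name below are assumptions, not part of this
# module.
#
# import argparse
# from transformers import AutoTokenizer
# from pytorch_lightning import Trainer
#
# parser = argparse.ArgumentParser()
# parser = UniversalDataModule.add_data_specific_args(parser)
# parser = Trainer.add_argparse_args(parser)   # PL 1.x-style Trainer flags
# args = parser.parse_args()
#
# tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
# data_module = UniversalDataModule(tokenizer, collate_fn=None, args=args)
# trainer = Trainer.from_argparse_args(args)
# trainer.fit(model, data_module)   # `model` is any LightningModule
# ---------------------------------------------------------------------------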