from typing import Optional

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, DistributedSampler


def get_consume_samples(data_model: LightningDataModule) -> int:
    """Return the number of training samples consumed so far.

    Prefer the `consumed_samples` attribute tracked by the LightningModule;
    otherwise estimate it from the global step, train batch size, world size
    and gradient-accumulation factor.
    """
    if hasattr(data_model.trainer.lightning_module, 'consumed_samples'):
        consumed_samples = data_model.trainer.lightning_module.consumed_samples
        print('get consumed samples from model: {}'.format(consumed_samples))
    else:
        world_size = data_model.trainer.world_size
        consumed_samples = max(0, data_model.trainer.global_step - 1) * \
            data_model.hparams.train_batchsize * world_size * data_model.trainer.accumulate_grad_batches
        print('calculate consumed samples: {}'.format(consumed_samples))
    return consumed_samples

class UniversalDataModule(LightningDataModule):
    """A general-purpose LightningDataModule.

    Datasets can be supplied directly, loaded by name through
    fengshen.data.fs_datasets, or read from local raw files (e.g. json)
    via Hugging Face `datasets.load_dataset`.
    """

    @staticmethod
    def add_data_specific_args(parent_args):
        parser = parent_args.add_argument_group('Universal DataModule')
        parser.add_argument('--num_workers', default=8, type=int)
        parser.add_argument('--dataloader_workers', default=2, type=int)
        parser.add_argument('--train_batchsize', default=32, type=int)
        parser.add_argument('--val_batchsize', default=32, type=int)
        parser.add_argument('--test_batchsize', default=32, type=int)
        parser.add_argument('--datasets_name', type=str, default=None)
        parser.add_argument('--train_datasets_field', type=str, default='train')
        parser.add_argument('--val_datasets_field', type=str, default='validation')
        parser.add_argument('--test_datasets_field', type=str, default='test')
        parser.add_argument('--train_file', type=str, default=None)
        parser.add_argument('--val_file', type=str, default=None)
        parser.add_argument('--test_file', type=str, default=None)
        parser.add_argument('--raw_file_type', type=str, default='json')
        parser.add_argument('--sampler_type', type=str,
                            choices=['single',
                                     'random'],
                            default='random')
        return parent_args
    def __init__(
        self,
        tokenizer,
        collate_fn,
        args,
        datasets=None,
        **kwargs,
    ):
        super().__init__()
        # If no dataset name is passed in, the internal datasets can be
        # replaced from outside the object with whatever the model needs.
        if datasets is not None:
            self.datasets = datasets
        elif args.datasets_name is not None:
            from fengshen.data.fs_datasets import load_dataset
            print('---------begin to load datasets {}'.format(args.datasets_name))
            self.datasets = load_dataset(
                args.datasets_name, num_proc=args.num_workers)
            print('---------end loading datasets {}'.format(args.datasets_name))
        else:
            print('---------begin to load datasets from local file')
            from datasets import load_dataset
            self.datasets = load_dataset(args.raw_file_type,
                                         data_files={
                                             args.train_datasets_field: args.train_file,
                                             args.val_datasets_field: args.val_file,
                                             args.test_datasets_field: args.test_file})
            print('---------end loading datasets from local file')
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.save_hyperparameters(args)
    def get_custom_sampler(self, ds):
        from .universal_sampler import PretrainingRandomSampler
        from .universal_sampler import PretrainingSampler
        world_size = self.trainer.world_size
        consumed_samples = get_consume_samples(self)
        # Build the batch sampler selected by the user via --sampler_type.
        if self.hparams.sampler_type == 'random':
            return PretrainingRandomSampler(
                total_samples=len(ds),
                # consumed_samples is derived from the global step
                # (see get_consume_samples above)
                consumed_samples=consumed_samples,
                micro_batch_size=self.hparams.train_batchsize,
                data_parallel_rank=self.trainer.global_rank,
                data_parallel_size=world_size,
                epoch=self.trainer.current_epoch,
            )
        elif self.hparams.sampler_type == 'single':
            return PretrainingSampler(
                total_samples=len(ds),
                # consumed_samples is derived from the global step
                # (see get_consume_samples above)
                consumed_samples=consumed_samples,
                micro_batch_size=self.hparams.train_batchsize,
                data_parallel_rank=self.trainer.global_rank,
                data_parallel_size=world_size,
            )
        else:
            raise Exception('Unknown sampler type: {}'.format(self.hparams.sampler_type))
    def setup(self, stage: Optional[str] = None) -> None:
        # Datasets are already loaded in __init__, so nothing to do here.
        return

    def train_dataloader(self):
        ds = self.datasets[self.hparams.train_datasets_field]

        collate_fn = self.collate_fn
        if collate_fn is None and hasattr(ds, 'collater'):
            collate_fn = ds.collater

        # If Lightning's automatic DistributedSampler injection is disabled,
        # fall back to the custom batch sampler defined above.
        if self.hparams.replace_sampler_ddp is False:
            return DataLoader(
                ds,
                batch_sampler=self.get_custom_sampler(ds),
                num_workers=self.hparams.dataloader_workers,
                collate_fn=collate_fn,
                pin_memory=True,
            )
        return DataLoader(
            ds,
            batch_size=self.hparams.train_batchsize,
            num_workers=self.hparams.dataloader_workers,
            collate_fn=collate_fn,
            pin_memory=True,
        )
    def val_dataloader(self):
        ds = self.datasets[self.hparams.val_datasets_field]

        collate_fn = self.collate_fn
        if collate_fn is None and hasattr(ds, 'collater'):
            collate_fn = ds.collater

        return DataLoader(
            ds,
            batch_size=self.hparams.val_batchsize,
            shuffle=False,
            num_workers=self.hparams.dataloader_workers,
            collate_fn=collate_fn,
            sampler=DistributedSampler(
                ds, shuffle=False),
            pin_memory=True,
        )

    def test_dataloader(self):
        ds = self.datasets[self.hparams.test_datasets_field]

        collate_fn = self.collate_fn
        if collate_fn is None and hasattr(ds, 'collater'):
            collate_fn = ds.collater

        return DataLoader(
            ds,
            batch_size=self.hparams.test_batchsize,
            shuffle=False,
            num_workers=self.hparams.dataloader_workers,
            collate_fn=collate_fn,
            sampler=DistributedSampler(
                ds, shuffle=False),
            pin_memory=True,
        )
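

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# A minimal, hypothetical example of wiring UniversalDataModule into a
# pytorch_lightning Trainer, assuming the Lightning 1.x argparse helpers
# and placeholder local JSON files. Paths, tokenizer and model are stand-ins.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse
    import pytorch_lightning as pl

    total_parser = argparse.ArgumentParser()
    # Trainer args supply flags such as --replace_sampler_ddp, which
    # train_dataloader() reads from hparams (Lightning 1.x API).
    total_parser = pl.Trainer.add_argparse_args(total_parser)
    total_parser = UniversalDataModule.add_data_specific_args(total_parser)
    args = total_parser.parse_args([
        '--train_file', 'train.json',   # placeholder paths; replace with real
        '--val_file', 'val.json',       # local files, or pass --datasets_name
        '--test_file', 'test.json',     # to load from fs_datasets instead
    ])

    # The tokenizer is only stored on the module; in real use pass a
    # Hugging Face tokenizer. collate_fn=None lets the DataLoader fall back
    # to default collation (or to the dataset's own `collater`, if any).
    data_module = UniversalDataModule(tokenizer=None, collate_fn=None, args=args)

    trainer = pl.Trainer.from_argparse_args(args, fast_dev_run=True)
    # trainer.fit(model, datamodule=data_module)  # model definition omitted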