"""
Script for decoding summarization models available through Huggingface Transformers.
Usage with Huggingface Datasets:
python generation.py --model <model name> --data_path <path to data in jsonl format>
Usage with custom datasets in JSONL format:
python generation.py --model <model name> --dataset <dataset name> --split <data split>
"""
#!/usr/bin/env python
# coding: utf-8
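# Illustrative invocations (a sketch, not from the original script): model, dataset and
# split names come from the argparse choices below; the JSONL path is hypothetical.
#   python generation.py --model bart-xsum --dataset xsum --split test
#   python generation.py --model pegasus-cnndm --data_path /path/to/my_data.jsonl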
import argparse
import json
import os
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
BATCH_SIZE = 8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BART_CNNDM_CHECKPOINT = 'facebook/bart-large-cnn'
BART_XSUM_CHECKPOINT = 'facebook/bart-large-xsum'
PEGASUS_CNNDM_CHECKPOINT = 'google/pegasus-cnn_dailymail'
PEGASUS_XSUM_CHECKPOINT = 'google/pegasus-xsum'
PEGASUS_NEWSROOM_CHECKPOINT = 'google/pegasus-newsroom'
PEGASUS_MULTINEWS_CHECKPOINT = 'google/pegasus-multi_news'
MODEL_CHECKPOINTS = {
'bart-xsum': BART_XSUM_CHECKPOINT,
'bart-cnndm': BART_CNNDM_CHECKPOINT,
'pegasus-xsum': PEGASUS_XSUM_CHECKPOINT,
'pegasus-cnndm': PEGASUS_CNNDM_CHECKPOINT,
'pegasus-newsroom': PEGASUS_NEWSROOM_CHECKPOINT,
'pegasus-multinews': PEGASUS_MULTINEWS_CHECKPOINT
}
class JSONDataset(torch.utils.data.Dataset):
    def __init__(self, data_path):
        super(JSONDataset, self).__init__()
        with open(data_path) as fd:
            self.data = [json.loads(line) for line in fd]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
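# Note on the custom JSONL format: each line is expected to be a JSON object containing
# at least the fields read later in the script ("article" and "target"), e.g. (illustrative):
#   {"article": "Full document text ...", "target": "Reference summary ..."}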
def preprocess_data(raw_data, dataset):
    """
    Unify the format of Huggingface Datasets examples.

    :param raw_data: loaded data
    :param dataset: name of dataset
    """
    if dataset == 'xsum':
        raw_data['article'] = raw_data['document']
        raw_data['target'] = raw_data['summary']
        del raw_data['document']
        del raw_data['summary']
    elif dataset == 'cnndm':
        raw_data['target'] = raw_data['highlights']
        del raw_data['highlights']
    elif dataset == 'gigaword':
        raw_data['article'] = raw_data['document']
        raw_data['target'] = raw_data['summary']
        del raw_data['document']
        del raw_data['summary']
    return raw_data
def postprocess_data(raw_data, decoded):
    """
    Remove generation artifacts and postprocess model outputs.

    :param raw_data: loaded data
    :param decoded: model outputs
    """
    raw_data['target'] = [x.replace('\n', ' ') for x in raw_data['target']]
    raw_data['decoded'] = [x.replace('<n>', ' ') for x in decoded]
    # Convert the batch (dict of lists) into a list of per-example dicts.
    return [dict(zip(raw_data, t)) for t in zip(*raw_data.values())]
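# Each record returned above pairs the source fields with the generated summary,
# e.g. (illustrative): {"article": "...", "target": "reference summary", "decoded": "model summary"}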
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Decode summarization models from Huggingface Transformers.')
    parser.add_argument('--model', type=str, required=True,
                        choices=['bart-xsum', 'bart-cnndm', 'pegasus-xsum', 'pegasus-cnndm',
                                 'pegasus-newsroom', 'pegasus-multinews'])
    parser.add_argument('--data_path', type=str)
    parser.add_argument('--dataset', type=str, choices=['xsum', 'cnndm', 'gigaword'])
    parser.add_argument('--split', type=str, choices=['train', 'validation', 'test'])
    args = parser.parse_args()

    if args.dataset and not args.split:
        raise RuntimeError('If the `dataset` flag is specified, `split` must also be provided.')
    if args.data_path:
        args.dataset = os.path.splitext(os.path.basename(args.data_path))[0]
        args.split = 'user'

    # Load model & tokenizer
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINTS[args.model]).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINTS[args.model])

    # Load data
    if not args.data_path:
        if args.dataset == 'cnndm':
            dataset = load_dataset('cnn_dailymail', '3.0.0', split=args.split)
        elif args.dataset == 'xsum':
            dataset = load_dataset('xsum', split=args.split)
        elif args.dataset == 'gigaword':
            dataset = load_dataset('gigaword', split=args.split)
    else:
        dataset = JSONDataset(args.data_path)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)

    # Run generation and write results to a JSONL file
    filename = '%s.%s.%s.results' % (args.model.replace("/", "-"), args.dataset, args.split)
    fd_out = open(filename, 'w')
    results = []
    model.eval()
    with torch.no_grad():
        for raw_data in tqdm(dataloader):
            raw_data = preprocess_data(raw_data, args.dataset)
            batch = tokenizer(raw_data["article"], return_tensors="pt", truncation=True, padding="longest").to(DEVICE)
            summaries = model.generate(input_ids=batch.input_ids, attention_mask=batch.attention_mask)
            decoded = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            result = postprocess_data(raw_data, decoded)
            results.extend(result)
            for example in result:
                fd_out.write(json.dumps(example) + '\n')
    fd_out.close()