Upload 25 files
Browse files- utils/__init__.py +0 -0
- utils/__pycache__/bert_model.cpython-39.pyc +0 -0
- utils/__pycache__/callbacks.cpython-39.pyc +0 -0
- utils/__pycache__/file_utils.cpython-39.pyc +0 -0
- utils/__pycache__/finetune.cpython-39.pyc +0 -0
- utils/__pycache__/lightning_base.cpython-39.pyc +0 -0
- utils/__pycache__/sentence_retrieval_model.cpython-39.pyc +0 -0
- utils/__pycache__/sentence_retrieval_module.cpython-39.pyc +0 -0
- utils/__pycache__/textual_entailment_module.cpython-39.pyc +0 -0
- utils/__pycache__/utils_graph2text.cpython-39.pyc +0 -0
- utils/__pycache__/utils_verbalisation_module.cpython-39.pyc +0 -0
- utils/__pycache__/verbalisation_module.cpython-39.pyc +0 -0
- utils/__pycache__/wikidata_utils.cpython-39.pyc +0 -0
- utils/bert_model.py +775 -0
- utils/callbacks.py +140 -0
- utils/file_utils.py +249 -0
- utils/finetune.py +633 -0
- utils/lightning_base.py +418 -0
- utils/sentence_retrieval_model.py +20 -0
- utils/sentence_retrieval_module.py +77 -0
- utils/textual_entailment_module.py +94 -0
- utils/utils_graph2text.py +114 -0
- utils/utils_verbalisation_module.py +610 -0
- utils/verbalisation_module.py +300 -0
- utils/wikidata_utils.py +173 -0
utils/__init__.py
ADDED
File without changes
|
utils/__pycache__/bert_model.cpython-39.pyc
ADDED
Binary file (30.6 kB). View file
|
|
utils/__pycache__/callbacks.cpython-39.pyc
ADDED
Binary file (4.9 kB). View file
|
|
utils/__pycache__/file_utils.cpython-39.pyc
ADDED
Binary file (6.81 kB). View file
|
|
utils/__pycache__/finetune.cpython-39.pyc
ADDED
Binary file (20.2 kB). View file
|
|
utils/__pycache__/lightning_base.cpython-39.pyc
ADDED
Binary file (13.5 kB). View file
|
|
utils/__pycache__/sentence_retrieval_model.cpython-39.pyc
ADDED
Binary file (1.11 kB). View file
|
|
utils/__pycache__/sentence_retrieval_module.cpython-39.pyc
ADDED
Binary file (2.52 kB). View file
|
|
utils/__pycache__/textual_entailment_module.cpython-39.pyc
ADDED
Binary file (2.65 kB). View file
|
|
utils/__pycache__/utils_graph2text.cpython-39.pyc
ADDED
Binary file (3.12 kB). View file
|
|
utils/__pycache__/utils_verbalisation_module.cpython-39.pyc
ADDED
Binary file (23.9 kB). View file
|
|
utils/__pycache__/verbalisation_module.cpython-39.pyc
ADDED
Binary file (7.37 kB). View file
|
|
utils/__pycache__/wikidata_utils.cpython-39.pyc
ADDED
Binary file (5.29 kB). View file
|
|
utils/bert_model.py
ADDED
@@ -0,0 +1,775 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch BERT model."""
|
17 |
+
|
18 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
19 |
+
|
20 |
+
import copy
|
21 |
+
import json
|
22 |
+
import logging
|
23 |
+
import math
|
24 |
+
import os
|
25 |
+
import shutil
|
26 |
+
import tarfile
|
27 |
+
import tempfile
|
28 |
+
import sys
|
29 |
+
from io import open
|
30 |
+
|
31 |
+
import torch
|
32 |
+
from torch import nn
|
33 |
+
from torch.nn import CrossEntropyLoss
|
34 |
+
|
35 |
+
from utils.file_utils import cached_path
|
36 |
+
|
37 |
+
logger = logging.getLogger(__name__)
|
38 |
+
|
39 |
+
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
40 |
+
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
|
41 |
+
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
|
42 |
+
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
|
43 |
+
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
|
44 |
+
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
|
45 |
+
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
|
46 |
+
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
|
47 |
+
}
|
48 |
+
CONFIG_NAME = 'bert_config.json'
|
49 |
+
WEIGHTS_NAME = 'pytorch_model.bin'
|
50 |
+
TF_WEIGHTS_NAME = 'model.ckpt'
|
51 |
+
|
52 |
+
def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
53 |
+
""" Load tf checkpoints in a pytorch model
|
54 |
+
"""
|
55 |
+
try:
|
56 |
+
import re
|
57 |
+
import numpy as np
|
58 |
+
import tensorflow as tf
|
59 |
+
except ImportError:
|
60 |
+
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
61 |
+
"https://www.tensorflow.org/install/ for installation instructions.")
|
62 |
+
raise
|
63 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
64 |
+
print("Converting TensorFlow checkpoint from {}".format(tf_path))
|
65 |
+
# Load weights from TF model
|
66 |
+
init_vars = tf.train.list_variables(tf_path)
|
67 |
+
names = []
|
68 |
+
arrays = []
|
69 |
+
for name, shape in init_vars:
|
70 |
+
print("Loading TF weight {} with shape {}".format(name, shape))
|
71 |
+
array = tf.train.load_variable(tf_path, name)
|
72 |
+
names.append(name)
|
73 |
+
arrays.append(array)
|
74 |
+
|
75 |
+
for name, array in zip(names, arrays):
|
76 |
+
name = name.split('/')
|
77 |
+
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
78 |
+
# which are not required for using pretrained model
|
79 |
+
if any(n in ["adam_v", "adam_m"] for n in name):
|
80 |
+
print("Skipping {}".format("/".join(name)))
|
81 |
+
continue
|
82 |
+
pointer = model
|
83 |
+
for m_name in name:
|
84 |
+
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
85 |
+
l = re.split(r'_(\d+)', m_name)
|
86 |
+
else:
|
87 |
+
l = [m_name]
|
88 |
+
if l[0] == 'kernel' or l[0] == 'gamma':
|
89 |
+
pointer = getattr(pointer, 'weight')
|
90 |
+
elif l[0] == 'output_bias' or l[0] == 'beta':
|
91 |
+
pointer = getattr(pointer, 'bias')
|
92 |
+
elif l[0] == 'output_weights':
|
93 |
+
pointer = getattr(pointer, 'weight')
|
94 |
+
else:
|
95 |
+
pointer = getattr(pointer, l[0])
|
96 |
+
if len(l) >= 2:
|
97 |
+
num = int(l[1])
|
98 |
+
pointer = pointer[num]
|
99 |
+
if m_name[-11:] == '_embeddings':
|
100 |
+
pointer = getattr(pointer, 'weight')
|
101 |
+
elif m_name == 'kernel':
|
102 |
+
array = np.transpose(array)
|
103 |
+
try:
|
104 |
+
assert pointer.shape == array.shape
|
105 |
+
except AssertionError as e:
|
106 |
+
e.args += (pointer.shape, array.shape)
|
107 |
+
raise
|
108 |
+
print("Initialize PyTorch weight {}".format(name))
|
109 |
+
pointer.data = torch.from_numpy(array)
|
110 |
+
return model
|
111 |
+
|
112 |
+
|
113 |
+
def gelu(x):
|
114 |
+
"""Implementation of the gelu activation function.
|
115 |
+
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
116 |
+
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
117 |
+
Also see https://arxiv.org/abs/1606.08415
|
118 |
+
"""
|
119 |
+
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
120 |
+
|
121 |
+
|
122 |
+
def swish(x):
|
123 |
+
return x * torch.sigmoid(x)
|
124 |
+
|
125 |
+
|
126 |
+
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
|
127 |
+
|
128 |
+
|
129 |
+
class BertConfig(object):
|
130 |
+
"""Configuration class to store the configuration of a `BertModel`.
|
131 |
+
"""
|
132 |
+
def __init__(self,
|
133 |
+
vocab_size_or_config_json_file,
|
134 |
+
hidden_size=768,
|
135 |
+
num_hidden_layers=12,
|
136 |
+
num_attention_heads=12,
|
137 |
+
intermediate_size=3072,
|
138 |
+
hidden_act="gelu",
|
139 |
+
hidden_dropout_prob=0.1,
|
140 |
+
attention_probs_dropout_prob=0.1,
|
141 |
+
max_position_embeddings=512,
|
142 |
+
type_vocab_size=2,
|
143 |
+
initializer_range=0.02):
|
144 |
+
"""Constructs BertConfig.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
|
148 |
+
hidden_size: Size of the encoder layers and the pooler layer.
|
149 |
+
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
150 |
+
num_attention_heads: Number of attention heads for each attention layer in
|
151 |
+
the Transformer encoder.
|
152 |
+
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
153 |
+
layer in the Transformer encoder.
|
154 |
+
hidden_act: The non-linear activation function (function or string) in the
|
155 |
+
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
|
156 |
+
hidden_dropout_prob: The dropout probabilitiy for all fully connected
|
157 |
+
layers in the embeddings, encoder, and pooler.
|
158 |
+
attention_probs_dropout_prob: The dropout ratio for the attention
|
159 |
+
probabilities.
|
160 |
+
max_position_embeddings: The maximum sequence length that this model might
|
161 |
+
ever be used with. Typically set this to something large just in case
|
162 |
+
(e.g., 512 or 1024 or 2048).
|
163 |
+
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
164 |
+
`BertModel`.
|
165 |
+
initializer_range: The sttdev of the truncated_normal_initializer for
|
166 |
+
initializing all weight matrices.
|
167 |
+
"""
|
168 |
+
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
169 |
+
and isinstance(vocab_size_or_config_json_file, unicode)):
|
170 |
+
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
171 |
+
json_config = json.loads(reader.read())
|
172 |
+
for key, value in json_config.items():
|
173 |
+
self.__dict__[key] = value
|
174 |
+
elif isinstance(vocab_size_or_config_json_file, int):
|
175 |
+
self.vocab_size = vocab_size_or_config_json_file
|
176 |
+
self.hidden_size = hidden_size
|
177 |
+
self.num_hidden_layers = num_hidden_layers
|
178 |
+
self.num_attention_heads = num_attention_heads
|
179 |
+
self.hidden_act = hidden_act
|
180 |
+
self.intermediate_size = intermediate_size
|
181 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
182 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
183 |
+
self.max_position_embeddings = max_position_embeddings
|
184 |
+
self.type_vocab_size = type_vocab_size
|
185 |
+
self.initializer_range = initializer_range
|
186 |
+
else:
|
187 |
+
raise ValueError("First argument must be either a vocabulary size (int)"
|
188 |
+
"or the path to a pretrained model config file (str)")
|
189 |
+
|
190 |
+
@classmethod
|
191 |
+
def from_dict(cls, json_object):
|
192 |
+
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
|
193 |
+
config = BertConfig(vocab_size_or_config_json_file=-1)
|
194 |
+
for key, value in json_object.items():
|
195 |
+
config.__dict__[key] = value
|
196 |
+
return config
|
197 |
+
|
198 |
+
@classmethod
|
199 |
+
def from_json_file(cls, json_file):
|
200 |
+
"""Constructs a `BertConfig` from a json file of parameters."""
|
201 |
+
with open(json_file, "r", encoding='utf-8') as reader:
|
202 |
+
text = reader.read()
|
203 |
+
return cls.from_dict(json.loads(text))
|
204 |
+
|
205 |
+
def __repr__(self):
|
206 |
+
return str(self.to_json_string())
|
207 |
+
|
208 |
+
def to_dict(self):
|
209 |
+
"""Serializes this instance to a Python dictionary."""
|
210 |
+
output = copy.deepcopy(self.__dict__)
|
211 |
+
return output
|
212 |
+
|
213 |
+
def to_json_string(self):
|
214 |
+
"""Serializes this instance to a JSON string."""
|
215 |
+
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
216 |
+
|
217 |
+
try:
|
218 |
+
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
|
219 |
+
except ImportError:
|
220 |
+
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
|
221 |
+
class BertLayerNorm(nn.Module):
|
222 |
+
def __init__(self, hidden_size, eps=1e-12):
|
223 |
+
"""Construct a layernorm module in the TF style (epsilon inside the square root).
|
224 |
+
"""
|
225 |
+
super(BertLayerNorm, self).__init__()
|
226 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
227 |
+
self.bias = nn.Parameter(torch.zeros(hidden_size))
|
228 |
+
self.variance_epsilon = eps
|
229 |
+
|
230 |
+
def forward(self, x):
|
231 |
+
u = x.mean(-1, keepdim=True)
|
232 |
+
s = (x - u).pow(2).mean(-1, keepdim=True)
|
233 |
+
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
234 |
+
return self.weight * x + self.bias
|
235 |
+
|
236 |
+
class BertEmbeddings(nn.Module):
|
237 |
+
"""Construct the embeddings from word, position and token_type embeddings.
|
238 |
+
"""
|
239 |
+
def __init__(self, config):
|
240 |
+
super(BertEmbeddings, self).__init__()
|
241 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
|
242 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
243 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
244 |
+
|
245 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
246 |
+
# any TensorFlow checkpoint file
|
247 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
248 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
249 |
+
|
250 |
+
def forward(self, input_ids, token_type_ids=None):
|
251 |
+
seq_length = input_ids.size(1)
|
252 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
253 |
+
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
254 |
+
if token_type_ids is None:
|
255 |
+
token_type_ids = torch.zeros_like(input_ids)
|
256 |
+
|
257 |
+
words_embeddings = self.word_embeddings(input_ids)
|
258 |
+
position_embeddings = self.position_embeddings(position_ids)
|
259 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
260 |
+
|
261 |
+
embeddings = words_embeddings + position_embeddings + token_type_embeddings
|
262 |
+
embeddings = self.LayerNorm(embeddings)
|
263 |
+
embeddings = self.dropout(embeddings)
|
264 |
+
return embeddings
|
265 |
+
|
266 |
+
|
267 |
+
class BertSelfAttention(nn.Module):
|
268 |
+
def __init__(self, config):
|
269 |
+
super(BertSelfAttention, self).__init__()
|
270 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
271 |
+
raise ValueError(
|
272 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
273 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
274 |
+
self.num_attention_heads = config.num_attention_heads
|
275 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
276 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
277 |
+
|
278 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
279 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
280 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
281 |
+
|
282 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
283 |
+
|
284 |
+
def transpose_for_scores(self, x):
|
285 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
286 |
+
x = x.view(*new_x_shape)
|
287 |
+
return x.permute(0, 2, 1, 3)
|
288 |
+
|
289 |
+
def forward(self, hidden_states, attention_mask):
|
290 |
+
mixed_query_layer = self.query(hidden_states)
|
291 |
+
mixed_key_layer = self.key(hidden_states)
|
292 |
+
mixed_value_layer = self.value(hidden_states)
|
293 |
+
|
294 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
295 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
296 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
297 |
+
|
298 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
299 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
300 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
301 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
302 |
+
attention_scores = attention_scores + attention_mask
|
303 |
+
|
304 |
+
# Normalize the attention scores to probabilities.
|
305 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
306 |
+
|
307 |
+
# This is actually dropping out entire tokens to attend to, which might
|
308 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
309 |
+
attention_probs = self.dropout(attention_probs)
|
310 |
+
|
311 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
312 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
313 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
314 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
315 |
+
return context_layer
|
316 |
+
|
317 |
+
|
318 |
+
class BertSelfOutput(nn.Module):
|
319 |
+
def __init__(self, config):
|
320 |
+
super(BertSelfOutput, self).__init__()
|
321 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
322 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
323 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
324 |
+
|
325 |
+
def forward(self, hidden_states, input_tensor):
|
326 |
+
hidden_states = self.dense(hidden_states)
|
327 |
+
hidden_states = self.dropout(hidden_states)
|
328 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
329 |
+
return hidden_states
|
330 |
+
|
331 |
+
|
332 |
+
class BertAttention(nn.Module):
|
333 |
+
def __init__(self, config):
|
334 |
+
super(BertAttention, self).__init__()
|
335 |
+
self.self = BertSelfAttention(config)
|
336 |
+
self.output = BertSelfOutput(config)
|
337 |
+
|
338 |
+
def forward(self, input_tensor, attention_mask):
|
339 |
+
self_output = self.self(input_tensor, attention_mask)
|
340 |
+
attention_output = self.output(self_output, input_tensor)
|
341 |
+
return attention_output
|
342 |
+
|
343 |
+
|
344 |
+
class BertIntermediate(nn.Module):
|
345 |
+
def __init__(self, config):
|
346 |
+
super(BertIntermediate, self).__init__()
|
347 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
348 |
+
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
349 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
350 |
+
else:
|
351 |
+
self.intermediate_act_fn = config.hidden_act
|
352 |
+
|
353 |
+
def forward(self, hidden_states):
|
354 |
+
hidden_states = self.dense(hidden_states)
|
355 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
356 |
+
return hidden_states
|
357 |
+
|
358 |
+
|
359 |
+
class BertOutput(nn.Module):
|
360 |
+
def __init__(self, config):
|
361 |
+
super(BertOutput, self).__init__()
|
362 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
363 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
364 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
365 |
+
|
366 |
+
def forward(self, hidden_states, input_tensor):
|
367 |
+
hidden_states = self.dense(hidden_states)
|
368 |
+
hidden_states = self.dropout(hidden_states)
|
369 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
370 |
+
return hidden_states
|
371 |
+
|
372 |
+
|
373 |
+
class BertLayer(nn.Module):
|
374 |
+
def __init__(self, config):
|
375 |
+
super(BertLayer, self).__init__()
|
376 |
+
self.attention = BertAttention(config)
|
377 |
+
self.intermediate = BertIntermediate(config)
|
378 |
+
self.output = BertOutput(config)
|
379 |
+
|
380 |
+
def forward(self, hidden_states, attention_mask):
|
381 |
+
attention_output = self.attention(hidden_states, attention_mask)
|
382 |
+
intermediate_output = self.intermediate(attention_output)
|
383 |
+
layer_output = self.output(intermediate_output, attention_output)
|
384 |
+
return layer_output
|
385 |
+
|
386 |
+
|
387 |
+
class BertEncoder(nn.Module):
|
388 |
+
def __init__(self, config):
|
389 |
+
super(BertEncoder, self).__init__()
|
390 |
+
layer = BertLayer(config)
|
391 |
+
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
392 |
+
|
393 |
+
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
|
394 |
+
all_encoder_layers = []
|
395 |
+
for layer_module in self.layer:
|
396 |
+
hidden_states = layer_module(hidden_states, attention_mask)
|
397 |
+
if output_all_encoded_layers:
|
398 |
+
all_encoder_layers.append(hidden_states)
|
399 |
+
if not output_all_encoded_layers:
|
400 |
+
all_encoder_layers.append(hidden_states)
|
401 |
+
return all_encoder_layers
|
402 |
+
|
403 |
+
|
404 |
+
class BertPooler(nn.Module):
|
405 |
+
def __init__(self, config):
|
406 |
+
super(BertPooler, self).__init__()
|
407 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
408 |
+
self.activation = nn.Tanh()
|
409 |
+
|
410 |
+
def forward(self, hidden_states):
|
411 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
412 |
+
# to the first token.
|
413 |
+
first_token_tensor = hidden_states[:, 0]
|
414 |
+
pooled_output = self.dense(first_token_tensor)
|
415 |
+
pooled_output = self.activation(pooled_output)
|
416 |
+
return pooled_output
|
417 |
+
|
418 |
+
|
419 |
+
class BertPredictionHeadTransform(nn.Module):
|
420 |
+
def __init__(self, config):
|
421 |
+
super(BertPredictionHeadTransform, self).__init__()
|
422 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
423 |
+
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
424 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
425 |
+
else:
|
426 |
+
self.transform_act_fn = config.hidden_act
|
427 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
428 |
+
|
429 |
+
def forward(self, hidden_states):
|
430 |
+
hidden_states = self.dense(hidden_states)
|
431 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
432 |
+
hidden_states = self.LayerNorm(hidden_states)
|
433 |
+
return hidden_states
|
434 |
+
|
435 |
+
|
436 |
+
class BertLMPredictionHead(nn.Module):
|
437 |
+
def __init__(self, config, bert_model_embedding_weights):
|
438 |
+
super(BertLMPredictionHead, self).__init__()
|
439 |
+
self.transform = BertPredictionHeadTransform(config)
|
440 |
+
|
441 |
+
# The output weights are the same as the input embeddings, but there is
|
442 |
+
# an output-only bias for each token.
|
443 |
+
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
|
444 |
+
bert_model_embedding_weights.size(0),
|
445 |
+
bias=False)
|
446 |
+
self.decoder.weight = bert_model_embedding_weights
|
447 |
+
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
|
448 |
+
|
449 |
+
def forward(self, hidden_states):
|
450 |
+
hidden_states = self.transform(hidden_states)
|
451 |
+
hidden_states = self.decoder(hidden_states) + self.bias
|
452 |
+
return hidden_states
|
453 |
+
|
454 |
+
|
455 |
+
class BertOnlyMLMHead(nn.Module):
|
456 |
+
def __init__(self, config, bert_model_embedding_weights):
|
457 |
+
super(BertOnlyMLMHead, self).__init__()
|
458 |
+
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
|
459 |
+
|
460 |
+
def forward(self, sequence_output):
|
461 |
+
prediction_scores = self.predictions(sequence_output)
|
462 |
+
return prediction_scores
|
463 |
+
|
464 |
+
|
465 |
+
class BertOnlyNSPHead(nn.Module):
|
466 |
+
def __init__(self, config):
|
467 |
+
super(BertOnlyNSPHead, self).__init__()
|
468 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
469 |
+
|
470 |
+
def forward(self, pooled_output):
|
471 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
472 |
+
return seq_relationship_score
|
473 |
+
|
474 |
+
|
475 |
+
class BertPreTrainingHeads(nn.Module):
|
476 |
+
def __init__(self, config, bert_model_embedding_weights):
|
477 |
+
super(BertPreTrainingHeads, self).__init__()
|
478 |
+
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
|
479 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
480 |
+
|
481 |
+
def forward(self, sequence_output, pooled_output):
|
482 |
+
prediction_scores = self.predictions(sequence_output)
|
483 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
484 |
+
return prediction_scores, seq_relationship_score
|
485 |
+
|
486 |
+
|
487 |
+
class BertPreTrainedModel(nn.Module):
|
488 |
+
""" An abstract class to handle weights initialization and
|
489 |
+
a simple interface for dowloading and loading pretrained models.
|
490 |
+
"""
|
491 |
+
def __init__(self, config, *inputs, **kwargs):
|
492 |
+
super(BertPreTrainedModel, self).__init__()
|
493 |
+
if not isinstance(config, BertConfig):
|
494 |
+
raise ValueError(
|
495 |
+
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
|
496 |
+
"To create a model from a Google pretrained model use "
|
497 |
+
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
498 |
+
self.__class__.__name__, self.__class__.__name__
|
499 |
+
))
|
500 |
+
self.config = config
|
501 |
+
|
502 |
+
def init_bert_weights(self, module):
|
503 |
+
""" Initialize the weights.
|
504 |
+
"""
|
505 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
506 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
507 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
508 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
509 |
+
elif isinstance(module, BertLayerNorm):
|
510 |
+
module.bias.data.zero_()
|
511 |
+
module.weight.data.fill_(1.0)
|
512 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
513 |
+
module.bias.data.zero_()
|
514 |
+
|
515 |
+
@classmethod
|
516 |
+
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
|
517 |
+
from_tf=False, *inputs, **kwargs):
|
518 |
+
"""
|
519 |
+
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
520 |
+
Download and cache the pre-trained model file if needed.
|
521 |
+
|
522 |
+
Params:
|
523 |
+
pretrained_model_name_or_path: either:
|
524 |
+
- a str with the name of a pre-trained model to load selected in the list of:
|
525 |
+
. `bert-base-uncased`
|
526 |
+
. `bert-large-uncased`
|
527 |
+
. `bert-base-cased`
|
528 |
+
. `bert-large-cased`
|
529 |
+
. `bert-base-multilingual-uncased`
|
530 |
+
. `bert-base-multilingual-cased`
|
531 |
+
. `bert-base-chinese`
|
532 |
+
- a path or url to a pretrained model archive containing:
|
533 |
+
. `bert_config.json` a configuration file for the model
|
534 |
+
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
|
535 |
+
- a path or url to a pretrained model archive containing:
|
536 |
+
. `bert_config.json` a configuration file for the model
|
537 |
+
. `model.chkpt` a TensorFlow checkpoint
|
538 |
+
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
|
539 |
+
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
|
540 |
+
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
|
541 |
+
*inputs, **kwargs: additional input for the specific Bert class
|
542 |
+
(ex: num_labels for BertForSequenceClassification)
|
543 |
+
"""
|
544 |
+
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
545 |
+
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
546 |
+
else:
|
547 |
+
archive_file = pretrained_model_name_or_path
|
548 |
+
# redirect to the cache, if necessary
|
549 |
+
try:
|
550 |
+
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
551 |
+
except EnvironmentError:
|
552 |
+
logger.error(
|
553 |
+
"Model name '{}' was not found in model name list ({}). "
|
554 |
+
"We assumed '{}' was a path or url but couldn't find any file "
|
555 |
+
"associated to this path or url.".format(
|
556 |
+
pretrained_model_name_or_path,
|
557 |
+
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
|
558 |
+
archive_file))
|
559 |
+
return None
|
560 |
+
if resolved_archive_file == archive_file:
|
561 |
+
logger.info("loading archive file {}".format(archive_file))
|
562 |
+
else:
|
563 |
+
logger.info("loading archive file {} from cache at {}".format(
|
564 |
+
archive_file, resolved_archive_file))
|
565 |
+
tempdir = None
|
566 |
+
if os.path.isdir(resolved_archive_file) or from_tf:
|
567 |
+
serialization_dir = resolved_archive_file
|
568 |
+
else:
|
569 |
+
# Extract archive to temp dir
|
570 |
+
tempdir = tempfile.mkdtemp()
|
571 |
+
logger.info("extracting archive file {} to temp dir {}".format(
|
572 |
+
resolved_archive_file, tempdir))
|
573 |
+
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
|
574 |
+
archive.extractall(tempdir)
|
575 |
+
serialization_dir = tempdir
|
576 |
+
# Load config
|
577 |
+
config_file = os.path.join(serialization_dir, CONFIG_NAME)
|
578 |
+
config = BertConfig.from_json_file(config_file)
|
579 |
+
logger.info("Model config {}".format(config))
|
580 |
+
# Instantiate model.
|
581 |
+
model = cls(config, *inputs, **kwargs)
|
582 |
+
if state_dict is None and not from_tf:
|
583 |
+
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
|
584 |
+
state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
|
585 |
+
if tempdir:
|
586 |
+
# Clean up temp dir
|
587 |
+
shutil.rmtree(tempdir)
|
588 |
+
if from_tf:
|
589 |
+
# Directly load from a TensorFlow checkpoint
|
590 |
+
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
|
591 |
+
return load_tf_weights_in_bert(model, weights_path)
|
592 |
+
# Load from a PyTorch state_dict
|
593 |
+
old_keys = []
|
594 |
+
new_keys = []
|
595 |
+
for key in state_dict.keys():
|
596 |
+
new_key = None
|
597 |
+
if 'gamma' in key:
|
598 |
+
new_key = key.replace('gamma', 'weight')
|
599 |
+
if 'beta' in key:
|
600 |
+
new_key = key.replace('beta', 'bias')
|
601 |
+
if new_key:
|
602 |
+
old_keys.append(key)
|
603 |
+
new_keys.append(new_key)
|
604 |
+
for old_key, new_key in zip(old_keys, new_keys):
|
605 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
606 |
+
|
607 |
+
missing_keys = []
|
608 |
+
unexpected_keys = []
|
609 |
+
error_msgs = []
|
610 |
+
# copy state_dict so _load_from_state_dict can modify it
|
611 |
+
metadata = getattr(state_dict, '_metadata', None)
|
612 |
+
state_dict = state_dict.copy()
|
613 |
+
if metadata is not None:
|
614 |
+
state_dict._metadata = metadata
|
615 |
+
|
616 |
+
def load(module, prefix=''):
|
617 |
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
618 |
+
module._load_from_state_dict(
|
619 |
+
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
|
620 |
+
for name, child in module._modules.items():
|
621 |
+
if child is not None:
|
622 |
+
load(child, prefix + name + '.')
|
623 |
+
start_prefix = ''
|
624 |
+
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
|
625 |
+
start_prefix = 'bert.'
|
626 |
+
load(model, prefix=start_prefix)
|
627 |
+
if len(missing_keys) > 0:
|
628 |
+
logger.info("Weights of {} not initialized from pretrained model: {}".format(
|
629 |
+
model.__class__.__name__, missing_keys))
|
630 |
+
if len(unexpected_keys) > 0:
|
631 |
+
logger.info("Weights from pretrained model not used in {}: {}".format(
|
632 |
+
model.__class__.__name__, unexpected_keys))
|
633 |
+
if len(error_msgs) > 0:
|
634 |
+
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
635 |
+
model.__class__.__name__, "\n\t".join(error_msgs)))
|
636 |
+
return model
|
637 |
+
|
638 |
+
|
639 |
+
class BertModel(BertPreTrainedModel):
|
640 |
+
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
|
641 |
+
|
642 |
+
Params:
|
643 |
+
config: a BertConfig class instance with the configuration to build a new model
|
644 |
+
|
645 |
+
Inputs:
|
646 |
+
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
647 |
+
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
648 |
+
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
649 |
+
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
650 |
+
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
651 |
+
a `sentence B` token (see BERT paper for more details).
|
652 |
+
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
|
653 |
+
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
654 |
+
input sequence length in the current batch. It's the mask that we typically use for attention when
|
655 |
+
a batch has varying length sentences.
|
656 |
+
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
|
657 |
+
|
658 |
+
Outputs: Tuple of (encoded_layers, pooled_output)
|
659 |
+
`encoded_layers`: controled by `output_all_encoded_layers` argument:
|
660 |
+
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
|
661 |
+
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
|
662 |
+
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
|
663 |
+
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
|
664 |
+
to the last attention block of shape [batch_size, sequence_length, hidden_size],
|
665 |
+
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
|
666 |
+
classifier pretrained on top of the hidden state associated to the first character of the
|
667 |
+
input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
|
668 |
+
|
669 |
+
Example usage:
|
670 |
+
```python
|
671 |
+
# Already been converted into WordPiece token ids
|
672 |
+
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
673 |
+
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
674 |
+
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
675 |
+
|
676 |
+
config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
677 |
+
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
678 |
+
|
679 |
+
model = modeling.BertModel(config=config)
|
680 |
+
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
681 |
+
```
|
682 |
+
"""
|
683 |
+
def __init__(self, config):
|
684 |
+
super(BertModel, self).__init__(config)
|
685 |
+
self.embeddings = BertEmbeddings(config)
|
686 |
+
self.encoder = BertEncoder(config)
|
687 |
+
self.pooler = BertPooler(config)
|
688 |
+
self.apply(self.init_bert_weights)
|
689 |
+
|
690 |
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
|
691 |
+
if attention_mask is None:
|
692 |
+
attention_mask = torch.ones_like(input_ids)
|
693 |
+
if token_type_ids is None:
|
694 |
+
token_type_ids = torch.zeros_like(input_ids)
|
695 |
+
|
696 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
697 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
698 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
699 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
700 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
701 |
+
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
702 |
+
|
703 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
704 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
705 |
+
# positions we want to attend and -10000.0 for masked positions.
|
706 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
707 |
+
# effectively the same as removing these entirely.
|
708 |
+
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
709 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
710 |
+
|
711 |
+
embedding_output = self.embeddings(input_ids, token_type_ids)
|
712 |
+
encoded_layers = self.encoder(embedding_output,
|
713 |
+
extended_attention_mask,
|
714 |
+
output_all_encoded_layers=output_all_encoded_layers)
|
715 |
+
sequence_output = encoded_layers[-1]
|
716 |
+
pooled_output = self.pooler(sequence_output)
|
717 |
+
if not output_all_encoded_layers:
|
718 |
+
encoded_layers = encoded_layers[-1]
|
719 |
+
return encoded_layers, pooled_output
|
720 |
+
|
721 |
+
|
722 |
+
|
723 |
+
|
724 |
+
|
725 |
+
class BertForSequenceEncoder(BertPreTrainedModel):
|
726 |
+
"""BERT model for classification.
|
727 |
+
This module is composed of the BERT model with a linear layer on top of
|
728 |
+
the pooled output.
|
729 |
+
Params:
|
730 |
+
`config`: a BertConfig class instance with the configuration to build a new model.
|
731 |
+
`num_labels`: the number of classes for the classifier. Default = 2.
|
732 |
+
Inputs:
|
733 |
+
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
|
734 |
+
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
|
735 |
+
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
|
736 |
+
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
|
737 |
+
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
|
738 |
+
a `sentence B` token (see BERT paper for more details).
|
739 |
+
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
|
740 |
+
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
|
741 |
+
input sequence length in the current batch. It's the mask that we typically use for attention when
|
742 |
+
a batch has varying length sentences.
|
743 |
+
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
|
744 |
+
with indices selected in [0, ..., num_labels].
|
745 |
+
Outputs:
|
746 |
+
if `labels` is not `None`:
|
747 |
+
Outputs the CrossEntropy classification loss of the output with the labels.
|
748 |
+
if `labels` is `None`:
|
749 |
+
Outputs the classification logits of shape [batch_size, num_labels].
|
750 |
+
Example usage:
|
751 |
+
```python
|
752 |
+
# Already been converted into WordPiece token ids
|
753 |
+
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
|
754 |
+
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
|
755 |
+
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
|
756 |
+
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
|
757 |
+
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
|
758 |
+
num_labels = 2
|
759 |
+
model = BertForSequenceClassification(config, num_labels)
|
760 |
+
logits = model(input_ids, token_type_ids, input_mask)
|
761 |
+
```
|
762 |
+
"""
|
763 |
+
def __init__(self, config):
|
764 |
+
super(BertForSequenceEncoder, self).__init__(config)
|
765 |
+
self.bert = BertModel(config)
|
766 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
767 |
+
self.apply(self.init_bert_weights)
|
768 |
+
|
769 |
+
def forward(self, input_ids, attention_mask, token_type_ids):
|
770 |
+
output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
771 |
+
output = self.dropout(output)
|
772 |
+
pooled_output = self.dropout(pooled_output)
|
773 |
+
return output, pooled_output
|
774 |
+
|
775 |
+
|
utils/callbacks.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import torch
|
8 |
+
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
9 |
+
from pytorch_lightning.utilities import rank_zero_only
|
10 |
+
|
11 |
+
from utils.utils_verbalisation_module import save_json
|
12 |
+
from pytorch_lightning.utilities import rank_zero_info
|
13 |
+
|
14 |
+
def count_trainable_parameters(model):
|
15 |
+
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
16 |
+
params = sum([np.prod(p.size()) for p in model_parameters])
|
17 |
+
return params
|
18 |
+
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
class Seq2SeqLoggingCallback(pl.Callback):
|
25 |
+
def on_batch_end(self, trainer, pl_module):
|
26 |
+
lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
|
27 |
+
pl_module.logger.log_metrics(lrs)
|
28 |
+
|
29 |
+
@rank_zero_only
|
30 |
+
def _write_logs(
|
31 |
+
self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
|
32 |
+
) -> None:
|
33 |
+
logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
|
34 |
+
metrics = trainer.callback_metrics
|
35 |
+
#print(metrics.keys())
|
36 |
+
new_metrics = {}
|
37 |
+
ms = ["log", "progress_bar", "preds"]
|
38 |
+
for k, v in metrics.items():
|
39 |
+
ver = True
|
40 |
+
for m in ms:
|
41 |
+
if m in k:
|
42 |
+
ver = False
|
43 |
+
break
|
44 |
+
if ver:
|
45 |
+
new_metrics[k] = v
|
46 |
+
|
47 |
+
print(new_metrics)
|
48 |
+
trainer.logger.log_metrics(new_metrics)
|
49 |
+
# Log results
|
50 |
+
od = Path(pl_module.hparams.output_dir)
|
51 |
+
if type_path == "test":
|
52 |
+
results_file = od / "test_results.txt"
|
53 |
+
generations_file = od / "test_generations.txt"
|
54 |
+
else:
|
55 |
+
# this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
|
56 |
+
# If people want this it will be easy enough to add back.
|
57 |
+
results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
|
58 |
+
generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
|
59 |
+
results_file.parent.mkdir(exist_ok=True)
|
60 |
+
generations_file.parent.mkdir(exist_ok=True)
|
61 |
+
with open(results_file, "a+") as writer:
|
62 |
+
for key in sorted(metrics):
|
63 |
+
if key in ["log", "progress_bar", "preds"]:
|
64 |
+
continue
|
65 |
+
try:
|
66 |
+
val = metrics[key]
|
67 |
+
if isinstance(val, torch.Tensor):
|
68 |
+
val = val.item()
|
69 |
+
msg = f"{key}: {val:.6f}\n"
|
70 |
+
writer.write(msg)
|
71 |
+
except:
|
72 |
+
pass
|
73 |
+
|
74 |
+
if not save_generations:
|
75 |
+
return
|
76 |
+
|
77 |
+
if "preds" in metrics:
|
78 |
+
content = "\n".join(metrics["preds"])
|
79 |
+
generations_file.open("w+").write(content)
|
80 |
+
|
81 |
+
@rank_zero_only
|
82 |
+
def on_train_start(self, trainer, pl_module):
|
83 |
+
try:
|
84 |
+
npars = pl_module.model.model.num_parameters()
|
85 |
+
except AttributeError:
|
86 |
+
npars = pl_module.model.num_parameters()
|
87 |
+
|
88 |
+
n_trainable_pars = count_trainable_parameters(pl_module)
|
89 |
+
# mp stands for million parameters
|
90 |
+
trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6})
|
91 |
+
|
92 |
+
@rank_zero_only
|
93 |
+
def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
|
94 |
+
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
95 |
+
return self._write_logs(trainer, pl_module, "test")
|
96 |
+
|
97 |
+
@rank_zero_only
|
98 |
+
def on_validation_end(self, trainer: pl.Trainer, pl_module):
|
99 |
+
save_json(pl_module.metrics, pl_module.metrics_save_path)
|
100 |
+
|
101 |
+
rank_zero_info("***** Validation results *****")
|
102 |
+
metrics = trainer.callback_metrics
|
103 |
+
# Log results
|
104 |
+
for key in sorted(metrics):
|
105 |
+
if key not in ["log", "progress_bar", "preds"]:
|
106 |
+
rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
|
107 |
+
# Uncommenting this will save val generations
|
108 |
+
# return self._write_logs(trainer, pl_module, "valid")
|
109 |
+
|
110 |
+
|
111 |
+
def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False):
|
112 |
+
"""Saves the best model by validation ROUGE2 score."""
|
113 |
+
if metric == "rouge2":
|
114 |
+
exp = "{val_avg_rouge2:.4f}-{step_count}"
|
115 |
+
elif metric == "bleu":
|
116 |
+
exp = "{val_avg_bleu:.4f}-{step_count}"
|
117 |
+
elif metric == "loss":
|
118 |
+
exp = "{val_avg_loss:.4f}-{step_count}"
|
119 |
+
else:
|
120 |
+
raise NotImplementedError(
|
121 |
+
f"seq2seq callbacks only support rouge2, bleu and loss, got {metric}, You can make your own by adding to this function."
|
122 |
+
)
|
123 |
+
|
124 |
+
checkpoint_callback = ModelCheckpoint(
|
125 |
+
filepath=os.path.join(output_dir, exp),
|
126 |
+
monitor=f"val_{metric}",
|
127 |
+
mode="min" if "loss" in metric else "max",
|
128 |
+
save_top_k=save_top_k,
|
129 |
+
period=0, # maybe save a checkpoint every time val is run, not just end of epoch.
|
130 |
+
)
|
131 |
+
return checkpoint_callback
|
132 |
+
|
133 |
+
|
134 |
+
def get_early_stopping_callback(metric, patience):
|
135 |
+
return EarlyStopping(
|
136 |
+
monitor=f"val_{metric}", # does this need avg?
|
137 |
+
mode="min" if "loss" in metric else "max",
|
138 |
+
patience=patience,
|
139 |
+
verbose=True,
|
140 |
+
)
|
utils/file_utils.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utilities for working with the local dataset cache.
|
3 |
+
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
4 |
+
Copyright by the AllenNLP authors.
|
5 |
+
"""
|
6 |
+
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
7 |
+
|
8 |
+
import json
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import shutil
|
12 |
+
import tempfile
|
13 |
+
from functools import wraps
|
14 |
+
from hashlib import sha256
|
15 |
+
import sys
|
16 |
+
from io import open
|
17 |
+
|
18 |
+
import boto3
|
19 |
+
import requests
|
20 |
+
from botocore.exceptions import ClientError
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
try:
|
24 |
+
from urllib.parse import urlparse
|
25 |
+
except ImportError:
|
26 |
+
from urlparse import urlparse
|
27 |
+
|
28 |
+
try:
|
29 |
+
from pathlib import Path
|
30 |
+
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
31 |
+
Path.home() / '.pytorch_pretrained_bert'))
|
32 |
+
except AttributeError:
|
33 |
+
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
34 |
+
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
|
35 |
+
|
36 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
37 |
+
|
38 |
+
|
39 |
+
def url_to_filename(url, etag=None):
|
40 |
+
"""
|
41 |
+
Convert `url` into a hashed filename in a repeatable way.
|
42 |
+
If `etag` is specified, append its hash to the url's, delimited
|
43 |
+
by a period.
|
44 |
+
"""
|
45 |
+
url_bytes = url.encode('utf-8')
|
46 |
+
url_hash = sha256(url_bytes)
|
47 |
+
filename = url_hash.hexdigest()
|
48 |
+
|
49 |
+
if etag:
|
50 |
+
etag_bytes = etag.encode('utf-8')
|
51 |
+
etag_hash = sha256(etag_bytes)
|
52 |
+
filename += '.' + etag_hash.hexdigest()
|
53 |
+
|
54 |
+
return filename
|
55 |
+
|
56 |
+
|
57 |
+
def filename_to_url(filename, cache_dir=None):
|
58 |
+
"""
|
59 |
+
Return the url and etag (which may be ``None``) stored for `filename`.
|
60 |
+
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
|
61 |
+
"""
|
62 |
+
if cache_dir is None:
|
63 |
+
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
64 |
+
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
65 |
+
cache_dir = str(cache_dir)
|
66 |
+
|
67 |
+
cache_path = os.path.join(cache_dir, filename)
|
68 |
+
if not os.path.exists(cache_path):
|
69 |
+
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag


def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper


@s3_request
def s3_etag(url):
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag


@s3_request
def s3_get(url, temp_file):
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url, temp_file):
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w', encoding="utf-8") as meta_file:
                json.dump(meta, meta_file)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path


def read_set_from_file(filename):
    '''
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    '''
    collection = set()
    with open(filename, 'r', encoding='utf-8') as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection


def get_file_extension(path, dot=True, lower=True):
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext
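A minimal usage sketch for the caching helpers above (illustrative only; the URL and cache directory are hypothetical placeholders, not values from this repository):

    from utils.file_utils import cached_path

    # Remote URLs are downloaded once, stored under a name derived from URL + ETag,
    # and later calls return the cached copy; existing local paths are returned as-is.
    local_copy = cached_path("https://example.org/vocab.txt", cache_dir="/tmp/prove_cache")
    print(local_copy)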
utils/finetune.py
ADDED
@@ -0,0 +1,633 @@
#!/usr/bin/env python

import argparse
import glob
import logging
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import pdb

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import rank_zero_info

from utils.callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from transformers import MBartTokenizer, T5ForConditionalGeneration

from transformers.models.bart.modeling_bart import shift_tokens_right
from utils.utils_verbalisation_module import (
    ROUGE_KEYS,
    LegacySeq2SeqDataset,
    Seq2SeqDataset,
    assert_all_frozen,
    calculate_bleu,
    calculate_rouge,
    flatten_list,
    freeze_embeds,
    freeze_params,
    label_smoothed_nll_loss,
    lmap,
    pickle_save,
    save_json,
    use_task_specific_params,
)

from utils.utils_graph2text import convert_text, eval_meteor, eval_bleu, eval_chrf, eval_meteor_test_webnlg, eval_chrf_test_webnlg

# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from utils.lightning_base import BaseTransformer, add_generic_args, generic_train  # noqa


logger = logging.getLogger(__name__)


class SummarizationModule(BaseTransformer):
    mode = "summarization"
    loss_names = ["loss"]
    metric_names = ROUGE_KEYS
    default_val_metric = "rouge2"

    def __init__(self, hparams, **kwargs):
        if hparams.sortish_sampler and hparams.gpus > 1:
            hparams.replace_sampler_ddp = False
        elif hparams.max_tokens_per_batch is not None:
            if hparams.gpus > 1:
                raise NotImplementedError("Dynamic Batch size does not work for multi-gpu training")
            if hparams.sortish_sampler:
                raise ValueError("--sortish_sampler and --max_tokens_per_batch may not be used simultaneously")

        super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs)
        #use_task_specific_params(self.model, "summarization")

        self.metrics_save_path = Path('base') / "metrics.json"
        self.hparams_save_path = Path('base') / "hparams.pkl"
        pickle_save(self.hparams, self.hparams_save_path)
        self.step_count = -2
        self.metrics = defaultdict(list)
        self.model_type = self.config.model_type
        self.vocab_size = self.config.tgt_vocab_size if self.model_type == "fsmt" else self.config.vocab_size

        if 't5' in hparams.model_name_or_path:
            self.model.config.prefix = 'translate Graph to English: '
        self.dataset_kwargs: dict = dict(
            data_dir=self.hparams.data_dir,
            max_source_length=self.hparams.max_source_length,
            prefix=self.model.config.prefix or "",
        )
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "val": self.hparams.n_val,
            "test_seen": self.hparams.n_test,
            "test_unseen": self.hparams.n_test,
            "test_both": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}

        self.target_lens = {
            "train": self.hparams.max_target_length,
            "val": self.hparams.val_max_target_length,
            "test_seen": self.hparams.test_max_target_length,
            "test_unseen": self.hparams.test_max_target_length,
            "test_both": self.hparams.test_max_target_length,
        }
        assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}"
        assert self.target_lens["train"] <= self.target_lens["test_both"], f"target_lens: {self.target_lens}"
        if self.hparams.freeze_embeds:
            freeze_embeds(self.model)
        if self.hparams.freeze_encoder:
            freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())

        self.num_workers = hparams.num_workers
        self.decoder_start_token_id = None  # default to config
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
            self.model.config.decoder_start_token_id = self.decoder_start_token_id
        self.dataset_class = (
            Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
        )
        self.already_saved_batch = False
        self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
        if self.hparams.eval_max_gen_length is not None:
            self.eval_max_length = self.hparams.eval_max_gen_length
        else:
            self.eval_max_length = self.model.config.max_length
        self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric

    def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
        """A debugging utility"""

        readable_batch = {
            k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
        }
        save_json(readable_batch, Path(self.output_dir) / "text_batch.json")

        tb = {}
        for k, v in batch.items():
            tb[k] = v.tolist()

        save_json(tb, Path(self.output_dir) / "tok_batch.json")

        self.already_saved_batch = True
        return readable_batch

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def ids_to_clean_text(self, generated_ids: List[int]):
        gen_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        return lmap(str.strip, gen_text)

    def _step(self, batch: dict) -> Tuple:
        pad_token_id = self.tokenizer.pad_token_id
        src_ids, src_mask = batch["input_ids"], batch["attention_mask"]
        if isinstance(self.model, T5ForConditionalGeneration):
            tgt_ids = batch["labels"]
            decoder_input_ids = self.model._shift_right(tgt_ids)
        else:
            #decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
            y = batch["labels"]
            decoder_input_ids = y[:, :-1].contiguous()
            tgt_ids = y[:, 1:].clone()
        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
            batch["decoder_input_ids"] = decoder_input_ids
            self.save_readable_batch(batch)

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        lm_logits = outputs[0]
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

            assert lm_logits.shape[-1] == self.vocab_size
            loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), tgt_ids.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            loss, nll_loss = label_smoothed_nll_loss(
                lprobs, tgt_ids, self.hparams.label_smoothing, ignore_index=pad_token_id
            )
        return (loss,)

    @property
    def pad(self) -> int:
        return self.tokenizer.pad_token_id

    def training_step(self, batch, batch_idx) -> Dict:
        loss_tensors = self._step(batch)

        logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        # tokens per batch
        logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["labels"].ne(self.pad).sum()
        logs["bs"] = batch["input_ids"].shape[0]
        logs["src_pad_tok"] = batch["input_ids"].eq(self.pad).sum()
        logs["src_pad_frac"] = batch["input_ids"].eq(self.pad).float().mean()
        # TODO(SS): make a wandb summary metric for this
        return {"loss": loss_tensors[0], "log": logs}

    def validation_step(self, batch, batch_idx) -> Dict:
        return self._generative_step(batch)

    def validation_epoch_end(self, outputs, prefix="val") -> Dict:

        self.step_count += 1

        val_outputs_folder = "val_outputs"
        os.system("mkdir -p " + os.path.join(self.hparams.output_dir, val_outputs_folder))

        if prefix == "val":
            output_test_predictions_file = os.path.join(
                self.hparams.output_dir, val_outputs_folder, "validation_predictions_" + str(self.step_count) + ".txt")
            output_test_targets_file = os.path.join(
                self.hparams.output_dir, val_outputs_folder, "validation_targets_" + str(self.step_count) + ".txt")
            # write predictions and targets for later rouge evaluation.
            with open(output_test_predictions_file, "w") as p_writer, open(output_test_targets_file, "w") as t_writer:
                for output_batch in outputs:
                    p_writer.writelines(convert_text(s) + "\n" for s in output_batch["preds"])
                    t_writer.writelines(convert_text(s) + "\n" for s in output_batch["target"])
                p_writer.close()
                t_writer.close()

            bleu_info = eval_bleu(self.hparams.data_dir, output_test_predictions_file, 'val')

            rank_zero_info("%s bleu_info: %s", self.step_count, bleu_info)

            if bleu_info == -1:
                bleu_info = float(bleu_info)
            else:
                bleu_info = float(bleu_info.split(",")[0].split("BLEU = ")[1])

            losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
            loss = losses["loss"]
            generative_metrics = {
                k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
            }

            generative_metrics['bleu'] = bleu_info

            metric_val = (
                generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
            )
            metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
            generative_metrics.update({k: v.item() for k, v in losses.items()})
            losses.update(generative_metrics)
            all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
            all_metrics["step_count"] = self.step_count
            self.metrics[prefix].append(all_metrics)  # callback writes this to self.metrics_save_path
            preds = flatten_list([x["preds"] for x in outputs])

            return {
                "bleu": bleu_info,
                "log": all_metrics,
                "preds": preds,
                f"{prefix}_loss": loss,
                f"{prefix}_{self.val_metric}": metric_tensor,
            }
        else:

            data_logs = {}
            for output in outputs:

                dataset_idx = output[0]['dataloader_idx']

                if dataset_idx == 0:
                    dataset_name = 'test_both'
                elif dataset_idx == 1:
                    dataset_name = 'test_seen'
                else:
                    dataset_name = 'test_unseen'

                if output[0]['bleu'] == -1:
                    bleu_info = float(output[0]['bleu'])
                else:
                    bleu_info = float(output[0]['bleu'].split(",")[0].split("BLEU = ")[1])

                losses = {k: torch.stack([x[k] for x in output]).mean() for k in self.loss_names}
                loss = losses["loss"]
                generative_metrics = {
                    k: np.array([x[k] for x in output]).mean() for k in self.metric_names + ["gen_time", "gen_len"]
                }

                generative_metrics['bleu'] = bleu_info

                metric_val = (
                    generative_metrics[self.val_metric] if self.val_metric in generative_metrics else losses[self.val_metric]
                )
                metric_tensor: torch.FloatTensor = torch.tensor(metric_val).type_as(loss)
                generative_metrics.update({k: v.item() for k, v in losses.items()})
                losses.update(generative_metrics)
                all_metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
                all_metrics["step_count"] = self.step_count
                self.metrics[prefix].append(all_metrics)  # callback writes this to self.metrics_save_path
                preds = flatten_list([x["preds"] for x in output])

                data_logs.update({
                    "log" + "_" + dataset_name: all_metrics,
                    "preds" + "_" + dataset_name: preds,
                    f"{prefix}_loss" + "_" + dataset_name: loss,
                    f"{prefix}_{self.val_metric}" + "_" + dataset_name: metric_tensor,
                })
            return data_logs


    #######


    def calc_generative_metrics(self, preds, target) -> Dict:
        return calculate_rouge(preds, target)

    def _generative_step(self, batch: dict, batch_idx=None, dataloader_idx=None) -> dict:
        t0 = time.time()

        # parser.add_argument('--eval_max_gen_length', type=int, default=None, help='never generate more than n tokens')
        generated_ids = self.model.generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=self.eval_beams,
            max_length=self.eval_max_length,
            length_penalty=1.0
        )
        gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
        preds: List[str] = self.ids_to_clean_text(generated_ids)
        target: List[str] = self.ids_to_clean_text(batch["labels"])
        loss_tensors = self._step(batch)
        base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)}
        rouge: Dict = self.calc_generative_metrics(preds, target)
        summ_len = np.mean(lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge)

        if dataloader_idx is not None:
            base_metrics.update(batch_idx=batch_idx, dataloader_idx=dataloader_idx)
        return base_metrics

    def test_step(self, batch, batch_idx, dataloader_idx):
        return self._generative_step(batch, batch_idx, dataloader_idx)

    def test_epoch_end(self, outputs_all_testsets):

        val_outputs_folder = "val_outputs"
        os.system("mkdir -p " + os.path.join(self.hparams.output_dir, val_outputs_folder))

        for outputs in outputs_all_testsets:
            dataset_idx = outputs[0]['dataloader_idx']

            if dataset_idx == 0:
                file_name = "test_both_predictions.txt"
                file_name_tgt = "test_both_targets.txt"
                dataset_name = 'test_both'
            elif dataset_idx == 1:
                file_name = "test_seen_predictions.txt"
                file_name_tgt = "test_seen_targets.txt"
                dataset_name = 'test_seen'
            else:
                file_name = "test_unseen_predictions.txt"
                file_name_tgt = "test_unseen_targets.txt"
                dataset_name = 'test_unseen'

            file_name += '.debug'
            file_name_tgt += '.debug'

            output_test_predictions_file = os.path.join(self.hparams.output_dir, val_outputs_folder, file_name)
            output_test_targets_file = os.path.join(self.hparams.output_dir, val_outputs_folder, file_name_tgt)
            # write predictions and targets for later rouge evaluation.
            with open(output_test_predictions_file, "w") as p_writer, open(output_test_targets_file, "w") as t_writer:
                for output_batch in outputs:

                    p_writer.writelines(convert_text(s) + "\n" for s in output_batch["preds"])
                    t_writer.writelines(convert_text(s) + "\n" for s in output_batch["target"])
                p_writer.close()
                t_writer.close()

            bleu_info = eval_bleu(self.hparams.data_dir, output_test_predictions_file, dataset_name)
            meteor_info = eval_meteor_test_webnlg(self.hparams.data_dir, output_test_predictions_file, dataset_name)
            chrf_info = eval_chrf_test_webnlg(self.hparams.data_dir, output_test_predictions_file, dataset_name)

            rank_zero_info(" %s - bleu_info: %s", dataset_name, bleu_info)
            rank_zero_info(" %s - meteor_info: %s", dataset_name, meteor_info)
            rank_zero_info(" %s - chrf_info: %s", dataset_name, chrf_info)

            outputs[0]['bleu'] = bleu_info

        return self.validation_epoch_end(outputs_all_testsets, prefix="test")

    def get_dataset(self, type_path) -> Seq2SeqDataset:
        n_obs = self.n_obs[type_path]
        max_target_length = self.target_lens[type_path]
        dataset = self.dataset_class(
            self.tokenizer,
            type_path=type_path,
            n_obs=n_obs,
            max_target_length=max_target_length,
            **self.dataset_kwargs,
        )
        return dataset

    def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
        dataset = self.get_dataset(type_path)

        if self.hparams.sortish_sampler and type_path != "test":
            sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=False,
                num_workers=self.num_workers,
                sampler=sampler,
            )

        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
            batch_sampler = dataset.make_dynamic_sampler(
                self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
            )
            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=dataset.collate_fn,
                # shuffle=False,
                num_workers=self.num_workers,
                # batch_size=None,
            )
        else:
            return DataLoader(
                dataset,
                batch_size=batch_size,
                collate_fn=dataset.collate_fn,
                shuffle=shuffle,
                num_workers=self.num_workers,
                sampler=None,
            )

    def train_dataloader(self) -> DataLoader:
        dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True)
        return dataloader

    def val_dataloader(self) -> DataLoader:
        return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self) -> List[DataLoader]:
        test_dataloader = self.get_dataloader("test_both", batch_size=self.hparams.eval_batch_size)
        test_seen_dataloader = self.get_dataloader("test_seen", batch_size=self.hparams.eval_batch_size)
        test_unseen_dataloader = self.get_dataloader("test_unseen", batch_size=self.hparams.eval_batch_size)

        return [test_dataloader, test_seen_dataloader, test_unseen_dataloader]

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        add_generic_args(parser, root_dir)
        parser.add_argument(
            "--max_source_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--max_target_length",
            default=56,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--val_max_target_length",
            default=142,  # these defaults are optimized for CNNDM. For xsum, see README.md.
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument(
            "--test_max_target_length",
            default=142,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )
        parser.add_argument("--freeze_encoder", action="store_true")
        parser.add_argument("--freeze_embeds", action="store_true")
        parser.add_argument("--sortish_sampler", action="store_true", default=False)
        parser.add_argument("--max_tokens_per_batch", type=int, default=None)
        parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default")
        parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_val", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.")
        parser.add_argument(
            "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all."
        )
        parser.add_argument("--label_smoothing", type=float, default=0.0, required=False)
        parser.add_argument("--src_lang", type=str, default="", required=False)
        parser.add_argument("--tgt_lang", type=str, default="", required=False)
        parser.add_argument("--eval_beams", type=int, default=None, required=False)
        parser.add_argument("--checkpoint", type=str, default=None, required=False)
        parser.add_argument(
            "--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
        )
        parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
        parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
        parser.add_argument(
            "--early_stopping_patience",
            type=int,
            default=-1,
            required=False,
            help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.",
        )

        return parser


class TranslationModule(SummarizationModule):
    mode = "translation"
    loss_names = ["loss"]
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)


class Graph2TextModule(SummarizationModule):
    mode = "graph2text"
    loss_names = ["loss"]
    metric_names = ["sacrebleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        if type(hparams) == dict:
            hparams = argparse.Namespace(**hparams)
        print(f'Graph2Text hparams are: {hparams}')
        super().__init__(hparams, **kwargs)

        self.hparams.update(vars(hparams))

        rank_zero_info("parameters %s", hparams)

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)


def main(args, model=None) -> SummarizationModule:
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if model is None:
        if "summarization" in args.task:
            model: SummarizationModule = SummarizationModule(args)
        elif "translation" in args.task:
            model: SummarizationModule = TranslationModule(args)
        else:
            model: SummarizationModule = Graph2TextModule(args)
    dataset = Path(args.data_dir).name
    if (
        args.logger_name == "default"
        or args.fast_dev_run
        or str(args.output_dir).startswith("/tmp")
        or str(args.output_dir).startswith("/var")
    ):
        logger = True  # don't pollute wandb logs unnecessarily
    elif args.logger_name == "wandb":
        from pytorch_lightning.loggers import WandbLogger

        project = os.environ.get("WANDB_PROJECT", dataset)
        logger = WandbLogger(name=model.output_dir.name, project=project)

    elif args.logger_name == "wandb_shared":
        from pytorch_lightning.loggers import WandbLogger

        logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}")

    if args.early_stopping_patience >= 0:
        es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience)
    else:
        es_callback = False

    lower_is_better = args.val_metric == "loss"
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir, model.val_metric, args.save_top_k, lower_is_better
        ),
        early_stopping_callback=es_callback,
        logger=logger,
    )
    pickle_save(model.hparams, model.output_dir / "hparams.pkl")
    if not args.do_predict:
        return model

    model.hparams.test_checkpoint = ""
    if not args.checkpoint:
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True)))
    else:
        checkpoints = [args.checkpoint]

    if checkpoints:
        model.hparams.test_checkpoint = checkpoints[-1]
        trainer.resume_from_checkpoint = checkpoints[-1]

    if args.do_predict and not args.do_train:

        checkpoint = checkpoints[-1]
        print(checkpoint)
        #trainer.test(ckpt_path=checkpoints[-1])
        trainer.test(model, ckpt_path=checkpoint)
        return model

    trainer.logger.log_hyperparams(model.hparams)

    # test() without a model tests using the best checkpoint automatically
    trainer.test()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())

    args = parser.parse_args()

    main(args)
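A hedged sketch of how finetune.py can be driven programmatically, mirroring its own __main__ block (all argument values below are placeholders, not hyperparameters recorded in this repository):

    import argparse
    import os

    import pytorch_lightning as pl

    from utils.finetune import SummarizationModule, main

    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
    # Hypothetical arguments; --task graph2text routes to Graph2TextModule in main().
    args = parser.parse_args([
        "--model_name_or_path", "t5-base",
        "--task", "graph2text",
        "--data_dir", "data/webnlg",
        "--output_dir", "outputs/graph2text",
        "--do_train",
    ])
    model = main(args)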
utils/lightning_base.py
ADDED
@@ -0,0 +1,418 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Any, Dict
import sys
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.callbacks import LearningRateMonitor

from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForQuestionAnswering,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoModelWithLMHead,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
)
from transformers.optimization import (
    Adafactor,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)

from tokenizers import AddedToken

logger = logging.getLogger(__name__)

MODEL_MODES = {
    "base": AutoModel,
    "sequence-classification": AutoModelForSequenceClassification,
    "question-answering": AutoModelForQuestionAnswering,
    "pretraining": AutoModelForPreTraining,
    "token-classification": AutoModelForTokenClassification,
    "language-modeling": AutoModelWithLMHead,
    "summarization": AutoModelForSeq2SeqLM,
    "translation": AutoModelForSeq2SeqLM,
    "graph2text": AutoModelForSeq2SeqLM,
}


# update this and the import above to support new schedulers from transformers.optimization
arg_to_scheduler = {
    "linear": get_linear_schedule_with_warmup,
    "cosine": get_cosine_schedule_with_warmup,
    "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup,
    "polynomial": get_polynomial_decay_schedule_with_warmup,
    # '': get_constant_schedule,             # not supported for now
    # '': get_constant_schedule_with_warmup, # not supported for now
}
arg_to_scheduler_choices = sorted(arg_to_scheduler.keys())
arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}"


class BaseTransformer(pl.LightningModule):
    def __init__(
        self,
        hparams: argparse.Namespace,
        num_labels=None,
        mode="base",
        config=None,
        tokenizer=None,
        model=None,
        **config_kwargs
    ):
        """Initialize a model, tokenizer and config."""
        super().__init__()
        # TODO: move to self.save_hyperparameters()
        # self.save_hyperparameters()
        # can also expand arguments into trainer signature for easier reading
        self.save_hyperparameters(hparams)
        self.step_count = -2
        self.output_dir = Path(self.hparams.output_dir)
        cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
        if config is None:
            self.config = AutoConfig.from_pretrained(
                self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
                **({"num_labels": num_labels} if num_labels is not None else {}),
                cache_dir=cache_dir,
                **config_kwargs,
            )
        else:
            self.config: PretrainedConfig = config

        extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
        for p in extra_model_params:
            if getattr(self.hparams, p, None):
                assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute"
                setattr(self.config, p, getattr(self.hparams, p))

        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
                cache_dir=cache_dir,
            )
            new_tokens = [
                '<H>', '<R>', '<T>'
            ]
            new_tokens_vocab = {}
            new_tokens_vocab['additional_special_tokens'] = []
            for idx, t in enumerate(new_tokens):
                new_tokens_vocab['additional_special_tokens'].append(t)
            num_added_toks = self.tokenizer.add_special_tokens(new_tokens_vocab)
            rank_zero_info('We have added %s tokens', num_added_toks)
        else:
            self.tokenizer: PreTrainedTokenizer = tokenizer
        self.model_type = MODEL_MODES[mode]
        if model is None:
            self.model = self.model_type.from_pretrained(
                self.hparams.model_name_or_path,
                from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
                config=self.config,
                cache_dir=cache_dir,
            )
            self.model.resize_token_embeddings(len(self.tokenizer))
        else:
            self.model = model

    def load_hf_checkpoint(self, *args, **kwargs):
        self.model = self.model_type.from_pretrained(*args, **kwargs)

    def get_lr_scheduler(self):
        get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler]
        scheduler = get_schedule_func(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return scheduler

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        if self.hparams.adafactor:
            optimizer = Adafactor(
                optimizer_grouped_parameters, lr=self.hparams.learning_rate, scale_parameter=False, relative_step=False
            )

        else:
            optimizer = AdamW(
                optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon
            )
        self.opt = optimizer

        scheduler = self.get_lr_scheduler()

        return [optimizer], [scheduler]

    def test_step(self, batch, batch_nb):
        return self.validation_step(batch, batch_nb)

    def test_epoch_end(self, outputs):
        return self.validation_end(outputs)

    @property
    def total_steps(self) -> int:
        # print('self.hparams.gpus', self.hparams.gpus)
        # print('self.hparams.accumulate_grad_batches', self.hparams.accumulate_grad_batches)
        # print('self.train_loader.dataset', self.train_loader.dataset)
        # print('self.hparams.max_epochs', self.hparams.max_epochs)
        # print('self.hparams.train_batch_size', self.hparams.train_batch_size)
        # exit()
        """The number of total training steps that will be run. Used for lr scheduler purposes."""
        num_devices = max(1, self.hparams.gpus)  # TODO: consider num_tpu_cores
        effective_batch_size = self.hparams.train_batch_size * self.hparams.accumulate_grad_batches * num_devices
        dataset_size = len(self.train_loader.dataset)
        return (dataset_size / effective_batch_size) * self.hparams.max_epochs

    def setup(self, mode):
        #if mode == "fit":
        self.train_loader = self.get_dataloader("train", self.hparams.train_batch_size, shuffle=True)

    def get_dataloader(self, type_path, batch_size, shuffle=False):
        raise NotImplementedError("You must implement this for your task")

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.get_dataloader("dev", self.hparams.eval_batch_size, shuffle=False)

    def test_dataloader(self):
        return self.get_dataloader("test", self.hparams.eval_batch_size, shuffle=False)

    def _feature_file(self, mode):
        return os.path.join(
            self.hparams.data_dir,
            "cached_{}_{}_{}".format(
                mode,
                list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(),
                str(self.hparams.max_seq_length),
            ),
        )

    def get_progress_bar_dict(self):
        #metrics = self.trainer.callback_metrics
        #print(self.trainer.lr_logger.lrs)
        lrs = self.trainer.lr_logger.lrs['lr-AdamW/pg1'][-1]
        running_train_loss = self.trainer.running_loss.mean()
        avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN')
        tqdm_dict = {"loss": "{:.3f}".format(avg_training_loss), "lr": lrs}
        return tqdm_dict

    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        save_path = self.output_dir.joinpath("best_tfmr")
        self.model.config.save_step = self.step_count
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        parser.add_argument(
            "--model_name_or_path",
            default=None,
            type=str,
            required=True,
            help="Path to pretrained model or model identifier from huggingface.co/models",
        )
        parser.add_argument(
            "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
        )
        parser.add_argument(
            "--tokenizer_name",
            default=None,
            type=str,
            help="Pretrained tokenizer name or path if not the same as model_name",
        )
        parser.add_argument(
            "--cache_dir",
            default="",
            type=str,
            help="Where do you want to store the pre-trained models downloaded from s3",
        )
        parser.add_argument(
            "--encoder_layerdrop",
            type=float,
            help="Encoder layer dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--decoder_layerdrop",
            type=float,
            help="Decoder layer dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--dropout",
            type=float,
            help="Dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument(
            "--attention_dropout",
            type=float,
            help="Attention dropout probability (Optional). Goes into model.config",
        )
        parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
        parser.add_argument(
            "--lr_scheduler",
            default="linear",
            choices=arg_to_scheduler_choices,
            metavar=arg_to_scheduler_metavar,
            type=str,
            help="Learning rate scheduler",
        )
        parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
        parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
        parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
        parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
        parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int)
        parser.add_argument("--train_batch_size", default=32, type=int)
        parser.add_argument("--eval_batch_size", default=32, type=int)
        parser.add_argument("--adafactor", action="store_true")


class LoggingCallback(pl.Callback):
    def on_batch_end(self, trainer, pl_module):
        lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
        pl_module.logger.log_metrics(lrs)

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        rank_zero_info("***** Validation results *****")
        metrics = trainer.callback_metrics
        rank_zero_info(trainer.logger)
        # Log results
        for key in sorted(metrics):
            if key not in ["log", "progress_bar"]:
                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))

    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        rank_zero_info("***** Test results *****")
        metrics = trainer.callback_metrics
        # Log and save results to file
        output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
        with open(output_test_results_file, "w") as writer:
            for key in sorted(metrics):
                if key not in ["log", "progress_bar"]:
                    rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
                    writer.write("{} = {}\n".format(key, str(metrics[key])))


def add_generic_args(parser, root_dir) -> None:
    # TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )

    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O2",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int)
    parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        dest="accumulate_grad_batches",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
    )


def generic_train(
    model: BaseTransformer,
    args: argparse.Namespace,
    early_stopping_callback=False,
    logger=True,  # can pass WandbLogger() here
    extra_callbacks=[],
    checkpoint_callback=None,
    logging_callback=None,
    **extra_train_kwargs
):
    pl.seed_everything(args.seed)

    # init model
    odir = Path(model.hparams.output_dir)
    odir.mkdir(exist_ok=True)

    # add custom checkpoints
    if checkpoint_callback is None:
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
        )
    if logging_callback is None:
        logging_callback = LoggingCallback()

    train_params = {}

    # TODO: remove with PyTorch 1.6 since pl uses native amp
    if args.fp16:
        train_params["precision"] = 16
        train_params["amp_level"] = args.fp16_opt_level

    if args.gpus > 1:
        train_params["distributed_backend"] = "ddp"

    train_params["accumulate_grad_batches"] = args.accumulate_grad_batches

    lr_logger = LearningRateMonitor(logging_interval='step')

    # deterministic=True,
    trainer = pl.Trainer.from_argparse_args(
        args,
        weights_summary='full',
        callbacks=[logging_callback, lr_logger],
        logger=logger,
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=early_stopping_callback,
        num_sanity_val_steps=4,
        **train_params,
    )

    trainer.lr_logger = lr_logger

    if args.do_train:
        trainer.fit(model)

    return trainer
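For context, a minimal sketch of the contract BaseTransformer places on subclasses: setup() builds the training loader through get_dataloader, which subclasses must override. The toy module and dummy dataset below are hypothetical illustrations; the real, concrete implementation is SummarizationModule in utils/finetune.py above.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from utils.lightning_base import BaseTransformer

    class ToySeq2SeqModule(BaseTransformer):
        mode = "summarization"  # selects AutoModelForSeq2SeqLM via MODEL_MODES

        def get_dataloader(self, type_path, batch_size, shuffle=False):
            # BaseTransformer.setup() requests the "train" loader through this hook;
            # val_dataloader()/test_dataloader() reuse it with "dev" and "test".
            dummy_ids = torch.randint(0, 100, (8, 16))
            return DataLoader(TensorDataset(dummy_ids), batch_size=batch_size, shuffle=shuffle)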
utils/sentence_retrieval_model.py
ADDED
@@ -0,0 +1,20 @@
import torch
import torch.nn as nn

from utils.bert_model import BertForSequenceEncoder

class sentence_retrieval_model(nn.Module):
    def __init__(self, args):
        super(sentence_retrieval_model, self).__init__()
        self.pred_model = BertForSequenceEncoder.from_pretrained(args['bert_pretrain'])
        self.bert_hidden_dim = args['bert_hidden_dim']
        self.dropout = nn.Dropout(args['dropout'])
        self.proj_match = nn.Linear(self.bert_hidden_dim, 1)


    def forward(self, inp_tensor, msk_tensor, seg_tensor):
        _, inputs = self.pred_model(inp_tensor, msk_tensor, seg_tensor)
        inputs = self.dropout(inputs)
        score = self.proj_match(inputs).squeeze(-1)
        score = torch.tanh(score)
        return score
utils/sentence_retrieval_module.py
ADDED
@@ -0,0 +1,77 @@
1 |
+
import re
|
2 |
+
from typing import List, Tuple
|
3 |
+
import pathlib
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import BertTokenizer
|
7 |
+
|
8 |
+
from utils.sentence_retrieval_model import sentence_retrieval_model
|
9 |
+
|
10 |
+
|
11 |
+
THIS_DIR = pathlib.Path(__file__).parent.absolute()
|
12 |
+
ARGS = {
|
13 |
+
'batch_size': 32,
|
14 |
+
'bert_pretrain': 'base/bert_base',
|
15 |
+
'checkpoint': 'base/model.best.32.pt',
|
16 |
+
'dropout': 0.6,
|
17 |
+
'bert_hidden_dim': 768,
|
18 |
+
'max_len': 384,
|
19 |
+
'cuda': torch.cuda.is_available()
|
20 |
+
}
|
21 |
+
|
22 |
+
if not ARGS['cuda']:
|
23 |
+
print('CUDA NOT AVAILABLE')
|
24 |
+
|
25 |
+
|
26 |
+
def process_sent(sentence):
|
27 |
+
sentence = re.sub("LSB.*?RSB", "", sentence)
|
28 |
+
sentence = re.sub("LRB\s*?RRB", "", sentence)
|
29 |
+
sentence = re.sub("(\s*?)LRB((\s*?))", "\\1(\\2", sentence)
|
30 |
+
sentence = re.sub("(\s*?)RRB((\s*?))", "\\1)\\2", sentence)
|
31 |
+
sentence = re.sub("--", "-", sentence)
|
32 |
+
sentence = re.sub("``", '"', sentence)
|
33 |
+
sentence = re.sub("''", '"', sentence)
|
34 |
+
return sentence
|
35 |
+
|
36 |
+
class SentenceRetrievalModule():
|
37 |
+
|
38 |
+
def __init__(self, max_len=None):
|
39 |
+
|
40 |
+
if max_len:
|
41 |
+
ARGS['max_len'] = max_len
|
42 |
+
|
43 |
+
self.tokenizer = BertTokenizer.from_pretrained(ARGS['bert_pretrain'], do_lower_case=False)
|
44 |
+
self.model = sentence_retrieval_model(ARGS)
|
45 |
+
self.model.load_state_dict(torch.load(ARGS['checkpoint'], map_location=torch.device('cpu'))['model'])
|
46 |
+
if ARGS['cuda']:
|
47 |
+
self.model = self.model.cuda()
|
48 |
+
|
49 |
+
def score_sentence_pairs(self, inputs: List[Tuple[str]]):
|
50 |
+
inputs_processed = [(process_sent(input[0]), process_sent(input[1])) for input in inputs]
|
51 |
+
|
52 |
+
encodings = self.tokenizer(
|
53 |
+
inputs_processed,
|
54 |
+
padding='max_length',
|
55 |
+
truncation='longest_first',
|
56 |
+
max_length=ARGS['max_len'],
|
57 |
+
return_token_type_ids=True,
|
58 |
+
return_attention_mask=True,
|
59 |
+
return_tensors='pt',
|
60 |
+
)
|
61 |
+
|
62 |
+
inp = encodings['input_ids']
|
63 |
+
msk = encodings['attention_mask']
|
64 |
+
seg = encodings['token_type_ids']
|
65 |
+
|
66 |
+
if ARGS['cuda']:
|
67 |
+
inp = inp.cuda()
|
68 |
+
msk = msk.cuda()
|
69 |
+
seg = seg.cuda()
|
70 |
+
|
71 |
+
self.model.eval()
|
72 |
+
with torch.no_grad():
|
73 |
+
outputs = self.model(inp, msk, seg).tolist()
|
74 |
+
|
75 |
+
assert len(outputs) == len(inputs)
|
76 |
+
|
77 |
+
return outputs
|
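A minimal usage sketch for the retrieval module above (assuming the base/bert_base weights and base/model.best.32.pt checkpoint referenced in ARGS are present):

from utils.sentence_retrieval_module import SentenceRetrievalModule

retriever = SentenceRetrievalModule(max_len=384)
# each input is a (claim, candidate evidence sentence) pair; scores are tanh-squashed, higher means more relevant
scores = retriever.score_sentence_pairs([
    ('The Eiffel Tower is in Paris.', 'The Eiffel Tower is a wrought-iron lattice tower in Paris, France.'),
])
print(scores)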
utils/textual_entailment_module.py
ADDED
@@ -0,0 +1,94 @@
1 |
+
import json
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
from pathlib import Path
|
5 |
+
import torch
|
6 |
+
import re
|
7 |
+
|
8 |
+
from transformers import BertTokenizer, BertForSequenceClassification
|
9 |
+
|
10 |
+
# Constants and paths
|
11 |
+
HOME = Path('/users/k2031554')
|
12 |
+
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
13 |
+
MAX_LEN = 512
|
14 |
+
CLASSES = ['SUPPORTS','REFUTES','NOT ENOUGH INFO']
|
15 |
+
METHODS = ['WEIGHTED_SUM', 'MALON']
|
16 |
+
|
17 |
+
def process_sent(sentence):
|
18 |
+
sentence = re.sub("LSB.*?RSB", "", sentence)
|
19 |
+
sentence = re.sub("LRB\s*?RRB", "", sentence)
|
20 |
+
sentence = re.sub("(\s*?)LRB((\s*?))", "\\1(\\2", sentence)
|
21 |
+
sentence = re.sub("(\s*?)RRB((\s*?))", "\\1)\\2", sentence)
|
22 |
+
sentence = re.sub("--", "-", sentence)
|
23 |
+
sentence = re.sub("``", '"', sentence)
|
24 |
+
sentence = re.sub("''", '"', sentence)
|
25 |
+
return sentence
|
26 |
+
|
27 |
+
class TextualEntailmentModule():
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
model_path = 'base/models/BERT_FEVER_v4_model_PBT',
|
32 |
+
tokenizer_path = 'base/models/BERT_FEVER_v4_tok_PBT'
|
33 |
+
):
|
34 |
+
self.tokenizer = BertTokenizer.from_pretrained(
|
35 |
+
tokenizer_path
|
36 |
+
)
|
37 |
+
self.model = BertForSequenceClassification.from_pretrained(
|
38 |
+
model_path
|
39 |
+
)
|
40 |
+
self.model.to(DEVICE)
|
41 |
+
|
42 |
+
#def get_pair_scores(self, claim, evidence):
|
43 |
+
#
|
44 |
+
# encodings = self.tokenizer(
|
45 |
+
# [claim, evidence],
|
46 |
+
# max_length= MAX_LEN,
|
47 |
+
# return_token_type_ids=False,
|
48 |
+
# padding='max_length',
|
49 |
+
# truncation=True,
|
50 |
+
# return_tensors='pt',
|
51 |
+
# ).to(DEVICE)
|
52 |
+
#
|
53 |
+
# self.model.eval()
|
54 |
+
# with torch.no_grad():
|
55 |
+
# probs = self.model(
|
56 |
+
# input_ids=encodings['input_ids'],
|
57 |
+
# attention_mask=encodings['attention_mask']
|
58 |
+
# )
|
59 |
+
#
|
60 |
+
# return torch.softmax(probs.logits,dim=1).cpu().numpy()
|
61 |
+
|
62 |
+
def get_batch_scores(self, claims, evidence):
|
63 |
+
|
64 |
+
inputs = list(zip(claims, evidence))
|
65 |
+
|
66 |
+
encodings = self.tokenizer(
|
67 |
+
inputs,
|
68 |
+
max_length= MAX_LEN,
|
69 |
+
return_token_type_ids=False,
|
70 |
+
padding='max_length',
|
71 |
+
truncation=True,
|
72 |
+
return_tensors='pt',
|
73 |
+
).to(DEVICE)
|
74 |
+
|
75 |
+
self.model.eval()
|
76 |
+
with torch.no_grad():
|
77 |
+
probs = self.model(
|
78 |
+
input_ids=encodings['input_ids'],
|
79 |
+
attention_mask=encodings['attention_mask']
|
80 |
+
)
|
81 |
+
|
82 |
+
return torch.softmax(probs.logits,dim=1).cpu().numpy()
|
83 |
+
|
84 |
+
def get_label_from_scores(self, scores):
|
85 |
+
return CLASSES[np.argmax(scores)]
|
86 |
+
|
87 |
+
def get_label_malon(self, score_set):
|
88 |
+
score_labels = [np.argmax(s) for s in score_set]
|
89 |
+
if 1 not in score_labels and 0 not in score_labels:
|
90 |
+
return CLASSES[2] #NOT ENOUGH INFO
|
91 |
+
elif 0 in score_labels:
|
92 |
+
return CLASSES[0] #SUPPORTS
|
93 |
+
elif 1 in score_labels:
|
94 |
+
return CLASSES[1] #REFUTES
|
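An illustrative sketch of how the entailment module is called (assuming the BERT_FEVER_v4 checkpoints referenced in the constructor defaults exist under base/models/):

from utils.textual_entailment_module import TextualEntailmentModule

te = TextualEntailmentModule()
claims = ['Paris is the capital of France.']
evidence = ['Paris is the capital and most populous city of France.']
scores = te.get_batch_scores(claims, evidence)   # softmax over SUPPORTS / REFUTES / NOT ENOUGH INFO
print(te.get_label_from_scores(scores[0]))       # label for one (claim, evidence) pair
print(te.get_label_malon(scores))                # aggregate label over a set of per-evidence scores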
utils/utils_graph2text.py
ADDED
@@ -0,0 +1,114 @@
1 |
+
import re
|
2 |
+
import os
|
3 |
+
|
4 |
+
def convert_text(text):
|
5 |
+
#return text
|
6 |
+
text = text.lower()
|
7 |
+
text = ' '.join(re.split('(\W)', text))
|
8 |
+
text = ' '.join(text.split())
|
9 |
+
return text
|
10 |
+
|
11 |
+
def eval_meteor_test_webnlg(folder_data, pred_file, dataset):
|
12 |
+
|
13 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
14 |
+
folder_data_before = dir_path + "/../utils"
|
15 |
+
|
16 |
+
cmd_string = "java -jar " + folder_data_before + "/meteor-1.5.jar " + pred_file + " " \
|
17 |
+
+ folder_data + "/" + dataset + ".target_eval_meteor -l en -norm -r 3 > " + pred_file.replace("txt", "meteor")
|
18 |
+
|
19 |
+
os.system(cmd_string)
|
20 |
+
|
21 |
+
meteor_info = open(pred_file.replace("txt", "meteor"), 'r').readlines()[-1].strip()
|
22 |
+
|
23 |
+
return meteor_info
|
24 |
+
|
25 |
+
|
26 |
+
def eval_chrf_test_webnlg(folder_data, pred_file, dataset):
|
27 |
+
|
28 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
29 |
+
folder_data_before = dir_path + "/../utils"
|
30 |
+
|
31 |
+
cmd_string = "python " + folder_data_before + "/chrf++.py -H " + pred_file + " -R " \
|
32 |
+
+ folder_data + "/" + dataset + ".target_eval_crf > " + pred_file.replace("txt", "chrf")
|
33 |
+
|
34 |
+
os.system(cmd_string)
|
35 |
+
|
36 |
+
chrf_info_1 = open(pred_file.replace("txt", "chrf"), 'r').readlines()[1].strip()
|
37 |
+
chrf_info_2 = open(pred_file.replace("txt", "chrf"), 'r').readlines()[2].strip()
|
38 |
+
|
39 |
+
return chrf_info_1 + " " + chrf_info_2
|
40 |
+
|
41 |
+
def eval_bleu(folder_data, pred_file, dataset):
|
42 |
+
|
43 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
44 |
+
folder_data_before = dir_path + "/data/"
|
45 |
+
|
46 |
+
cmd_string = "perl " + folder_data_before + "/multi-bleu.perl -lc " + folder_data + "/" + dataset + ".target_eval " \
|
47 |
+
+ folder_data + "/" + dataset + ".target2_eval " + folder_data + "/" + dataset + ".target3_eval < " \
|
48 |
+
+ pred_file + " > " + pred_file.replace("txt", "bleu")
|
49 |
+
|
50 |
+
os.system(cmd_string)
|
51 |
+
|
52 |
+
try:
|
53 |
+
bleu_info = open(pred_file.replace("txt", "bleu"), 'r').readlines()[0].strip()
|
54 |
+
except:
|
55 |
+
bleu_info = -1
|
56 |
+
|
57 |
+
return bleu_info
|
58 |
+
|
59 |
+
|
60 |
+
def eval_bleu_sents_tok(pred_file, folder_data, dataset):
|
61 |
+
|
62 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
63 |
+
folder_data_before = dir_path + "/../utils"
|
64 |
+
|
65 |
+
cmd_string = "perl " + folder_data_before + "/tokenizer.perl -threads 4 -no-escape < " + pred_file + " > " +\
|
66 |
+
pred_file + "_tok"
|
67 |
+
os.system(cmd_string)
|
68 |
+
|
69 |
+
cmd_string = "perl " + folder_data_before + "/multi-bleu.perl -lc " + folder_data + "/" + dataset + ".target.tok"\
|
70 |
+
+ " < " + pred_file + "_tok" + " > " + pred_file.replace("txt", "bleu_data")
|
71 |
+
os.system(cmd_string)
|
72 |
+
|
73 |
+
try:
|
74 |
+
bleu_info_data = open(pred_file.replace("txt", "bleu_data"), 'r').readlines()[0].strip()
|
75 |
+
except:
|
76 |
+
bleu_info_data = 'no data'
|
77 |
+
|
78 |
+
return bleu_info_data
|
79 |
+
|
80 |
+
|
81 |
+
def eval_meteor(ref_file, pred_file):
|
82 |
+
|
83 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
84 |
+
folder_data_before = dir_path + "/../utils"
|
85 |
+
|
86 |
+
cmd_string = "java -jar " + folder_data_before + "/meteor-1.5.jar " + pred_file + " " \
|
87 |
+
+ ref_file + " > " + pred_file.replace("txt", "meteor")
|
88 |
+
|
89 |
+
os.system(cmd_string)
|
90 |
+
|
91 |
+
meteor_info = open(pred_file.replace("txt", "meteor"), 'r').readlines()[-1].strip()
|
92 |
+
|
93 |
+
return meteor_info
|
94 |
+
|
95 |
+
|
96 |
+
def eval_chrf(ref_file, pred_file):
|
97 |
+
|
98 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
99 |
+
folder_data_before = dir_path + "/../utils"
|
100 |
+
|
101 |
+
cmd_string = "python " + folder_data_before + "/chrf++.py -H " + pred_file + " -R " \
|
102 |
+
+ ref_file + " > " + pred_file.replace("txt", "chrf")
|
103 |
+
|
104 |
+
os.system(cmd_string)
|
105 |
+
|
106 |
+
try:
|
107 |
+
chrf_info_1 = open(pred_file.replace("txt", "chrf"), 'r').readlines()[1].strip()
|
108 |
+
chrf_info_2 = open(pred_file.replace("txt", "chrf"), 'r').readlines()[2].strip()
|
109 |
+
chrf_data = chrf_info_1 + " " + chrf_info_2
|
110 |
+
except:
|
111 |
+
chrf_data = "no data"
|
112 |
+
|
113 |
+
|
114 |
+
return chrf_data
|
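These helpers shell out to external scorers (meteor-1.5.jar, chrf++.py, multi-bleu.perl, tokenizer.perl), so they only run if those files sit in the expected folders; a hypothetical call, with made-up file names, looks like:

from utils.utils_graph2text import eval_meteor, eval_chrf

# pred.txt holds one generated sentence per line, ref.txt the corresponding references
print(eval_meteor('ref.txt', 'pred.txt'))   # writes pred.meteor and returns its last line
print(eval_chrf('ref.txt', 'pred.txt'))     # writes pred.chrf and returns its two summary lines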
utils/utils_verbalisation_module.py
ADDED
@@ -0,0 +1,610 @@
1 |
+
import itertools
|
2 |
+
import json
|
3 |
+
import linecache
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import pickle
|
7 |
+
import socket
|
8 |
+
from logging import getLogger
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Callable, Dict, Iterable, List, Tuple, Union
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.distributed as dist
|
15 |
+
from rouge_score import rouge_scorer, scoring
|
16 |
+
from sacrebleu import corpus_bleu
|
17 |
+
from torch import nn
|
18 |
+
from torch.utils.data import Dataset, Sampler
|
19 |
+
|
20 |
+
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
|
21 |
+
from transformers.file_utils import cached_property
|
22 |
+
from transformers.models.bart.modeling_bart import shift_tokens_right
|
23 |
+
from utils.utils_graph2text import convert_text, eval_bleu
|
24 |
+
from pytorch_lightning.utilities import rank_zero_info
|
25 |
+
import pdb
|
26 |
+
|
27 |
+
|
28 |
+
try:
|
29 |
+
from fairseq.data.data_utils import batch_by_size
|
30 |
+
|
31 |
+
FAIRSEQ_AVAILABLE = True
|
32 |
+
except (ImportError, ModuleNotFoundError):
|
33 |
+
FAIRSEQ_AVAILABLE = False
|
34 |
+
|
35 |
+
|
36 |
+
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
|
37 |
+
"""From fairseq"""
|
38 |
+
if target.dim() == lprobs.dim() - 1:
|
39 |
+
target = target.unsqueeze(-1)
|
40 |
+
nll_loss = -lprobs.gather(dim=-1, index=target)
|
41 |
+
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
|
42 |
+
if ignore_index is not None:
|
43 |
+
pad_mask = target.eq(ignore_index)
|
44 |
+
nll_loss.masked_fill_(pad_mask, 0.0)
|
45 |
+
smooth_loss.masked_fill_(pad_mask, 0.0)
|
46 |
+
else:
|
47 |
+
nll_loss = nll_loss.squeeze(-1)
|
48 |
+
smooth_loss = smooth_loss.squeeze(-1)
|
49 |
+
|
50 |
+
nll_loss = nll_loss.sum() # mean()? Scared to break other math.
|
51 |
+
smooth_loss = smooth_loss.sum()
|
52 |
+
eps_i = epsilon / lprobs.size(-1)
|
53 |
+
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
|
54 |
+
return loss, nll_loss
|
55 |
+
|
56 |
+
|
57 |
+
def lmap(f: Callable, x: Iterable) -> List:
|
58 |
+
"""list(map(f, x))"""
|
59 |
+
return list(map(f, x))
|
60 |
+
|
61 |
+
|
62 |
+
def calculate_bleu(output_lns, refs_lns) -> dict:
|
63 |
+
"""Uses sacrebleu's corpus_bleu implementation."""
|
64 |
+
return {"sacrebleu": round(corpus_bleu(output_lns, [refs_lns]).score, 4)}
|
65 |
+
|
66 |
+
|
67 |
+
def build_compute_metrics_fn(task_name: str, tokenizer: PreTrainedTokenizer) -> Callable[[EvalPrediction], Dict]:
|
68 |
+
def non_pad_len(tokens: np.ndarray) -> int:
|
69 |
+
return np.count_nonzero(tokens != tokenizer.pad_token_id)
|
70 |
+
|
71 |
+
def decode_pred(pred: EvalPrediction) -> Tuple[List[str], List[str]]:
|
72 |
+
pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
|
73 |
+
label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)
|
74 |
+
pred_str = lmap(str.strip, pred_str)
|
75 |
+
label_str = lmap(str.strip, label_str)
|
76 |
+
return pred_str, label_str
|
77 |
+
|
78 |
+
def summarization_metrics(pred: EvalPrediction) -> Dict:
|
79 |
+
pred_str, label_str = decode_pred(pred)
|
80 |
+
rouge: Dict = calculate_rouge(pred_str, label_str)
|
81 |
+
summ_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
82 |
+
rouge.update({"gen_len": summ_len})
|
83 |
+
return rouge
|
84 |
+
|
85 |
+
def translation_metrics(pred: EvalPrediction) -> Dict:
|
86 |
+
pred_str, label_str = decode_pred(pred)
|
87 |
+
bleu: Dict = calculate_bleu(pred_str, label_str)
|
88 |
+
gen_len = np.round(np.mean(lmap(non_pad_len, pred.predictions)), 1)
|
89 |
+
bleu.update({"gen_len": gen_len})
|
90 |
+
return bleu
|
91 |
+
|
92 |
+
compute_metrics_fn = summarization_metrics if "summarization" in task_name else translation_metrics
|
93 |
+
return compute_metrics_fn
|
94 |
+
|
95 |
+
|
96 |
+
def trim_batch(
|
97 |
+
input_ids,
|
98 |
+
pad_token_id,
|
99 |
+
attention_mask=None,
|
100 |
+
):
|
101 |
+
"""Remove columns that are populated exclusively by pad_token_id"""
|
102 |
+
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
|
103 |
+
if attention_mask is None:
|
104 |
+
return input_ids[:, keep_column_mask]
|
105 |
+
else:
|
106 |
+
return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
|
107 |
+
|
108 |
+
|
109 |
+
class AbstractSeq2SeqDataset(Dataset):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
tokenizer,
|
113 |
+
data_dir,
|
114 |
+
max_source_length,
|
115 |
+
max_target_length,
|
116 |
+
type_path="train",
|
117 |
+
n_obs=None,
|
118 |
+
prefix="",
|
119 |
+
**dataset_kwargs
|
120 |
+
):
|
121 |
+
super().__init__()
|
122 |
+
self.src_file = Path(data_dir).joinpath(type_path + ".source")
|
123 |
+
self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
|
124 |
+
self.len_file = Path(data_dir).joinpath(type_path + ".len")
|
125 |
+
if os.path.exists(self.len_file):
|
126 |
+
self.src_lens = pickle_load(self.len_file)
|
127 |
+
self.used_char_len = False
|
128 |
+
else:
|
129 |
+
self.src_lens = self.get_char_lens(self.src_file)
|
130 |
+
self.used_char_len = True
|
131 |
+
self.max_source_length = max_source_length
|
132 |
+
self.max_target_length = max_target_length
|
133 |
+
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
|
134 |
+
self.tokenizer = tokenizer
|
135 |
+
self.prefix = prefix if prefix is not None else ""
|
136 |
+
|
137 |
+
if n_obs is not None:
|
138 |
+
self.src_lens = self.src_lens[:n_obs]
|
139 |
+
self.pad_token_id = self.tokenizer.pad_token_id
|
140 |
+
self.dataset_kwargs = dataset_kwargs
|
141 |
+
dataset_kwargs.update({"add_prefix_space": True} if isinstance(self.tokenizer, BartTokenizer) else {})
|
142 |
+
|
143 |
+
def __len__(self):
|
144 |
+
return len(self.src_lens)
|
145 |
+
|
146 |
+
@staticmethod
|
147 |
+
def get_char_lens(data_file):
|
148 |
+
return [len(x) for x in Path(data_file).open().readlines()]
|
149 |
+
|
150 |
+
@cached_property
|
151 |
+
def tgt_lens(self):
|
152 |
+
"""Length in characters of target documents"""
|
153 |
+
return self.get_char_lens(self.tgt_file)
|
154 |
+
|
155 |
+
def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
|
156 |
+
if distributed:
|
157 |
+
return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
|
158 |
+
else:
|
159 |
+
return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
|
160 |
+
|
161 |
+
def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
|
162 |
+
assert FAIRSEQ_AVAILABLE, "Dynamic batch size requires `pip install fairseq`"
|
163 |
+
assert not self.used_char_len, "You must call python make_len_file.py before calling make_dynamic_sampler"
|
164 |
+
sorted_indices = list(self.make_sortish_sampler(1024, shuffle=False))
|
165 |
+
|
166 |
+
def num_tokens_in_example(i):
|
167 |
+
return min(self.src_lens[i], self.max_target_length)
|
168 |
+
|
169 |
+
# call fairseq cython function
|
170 |
+
batch_sampler: List[List[int]] = batch_by_size(
|
171 |
+
sorted_indices,
|
172 |
+
num_tokens_fn=num_tokens_in_example,
|
173 |
+
max_tokens=max_tokens_per_batch,
|
174 |
+
required_batch_size_multiple=64,
|
175 |
+
)
|
176 |
+
shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]
|
177 |
+
# move the largest batch to the front to OOM quickly (uses an approximation for padding)
|
178 |
+
approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
|
179 |
+
largest_batch_idx = np.argmax(approximate_toks_per_batch)
|
180 |
+
shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
|
181 |
+
shuffled_batches[largest_batch_idx],
|
182 |
+
shuffled_batches[0],
|
183 |
+
)
|
184 |
+
return shuffled_batches
|
185 |
+
|
186 |
+
def __getitem__(self, item):
|
187 |
+
raise NotImplementedError("You must implement this")
|
188 |
+
|
189 |
+
def collate_fn(self, batch):
|
190 |
+
raise NotImplementedError("You must implement this")
|
191 |
+
|
192 |
+
|
193 |
+
class LegacySeq2SeqDataset(AbstractSeq2SeqDataset):
|
194 |
+
def __getitem__(self, index) -> Dict[str, torch.Tensor]:
|
195 |
+
"""Call tokenizer on src and tgt_lines"""
|
196 |
+
|
197 |
+
|
198 |
+
index = index + 1 # linecache starts at 1
|
199 |
+
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
200 |
+
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
201 |
+
assert source_line, f"empty source line for index {index}"
|
202 |
+
assert tgt_line, f"empty tgt line for index {index}"
|
203 |
+
source_inputs = self.encode_line(self.tokenizer, source_line, self.max_source_length)
|
204 |
+
target_inputs = self.encode_line(self.tokenizer, tgt_line, self.max_target_length)
|
205 |
+
|
206 |
+
source_ids = source_inputs["input_ids"].squeeze()
|
207 |
+
target_ids = target_inputs["input_ids"].squeeze()
|
208 |
+
src_mask = source_inputs["attention_mask"].squeeze()
|
209 |
+
return {
|
210 |
+
"input_ids": source_ids,
|
211 |
+
"attention_mask": src_mask,
|
212 |
+
"labels": target_ids,
|
213 |
+
}
|
214 |
+
|
215 |
+
def encode_line(self, tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"):
|
216 |
+
"""Only used by LegacyDataset"""
|
217 |
+
return tokenizer(
|
218 |
+
[line],
|
219 |
+
max_length=max_length,
|
220 |
+
padding="max_length" if pad_to_max_length else None,
|
221 |
+
truncation=True,
|
222 |
+
return_tensors=return_tensors,
|
223 |
+
**self.dataset_kwargs,
|
224 |
+
)
|
225 |
+
|
226 |
+
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
227 |
+
input_ids = torch.stack([x["input_ids"] for x in batch])
|
228 |
+
masks = torch.stack([x["attention_mask"] for x in batch])
|
229 |
+
target_ids = torch.stack([x["labels"] for x in batch])
|
230 |
+
pad_token_id = self.pad_token_id
|
231 |
+
y = trim_batch(target_ids, pad_token_id)
|
232 |
+
source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)
|
233 |
+
batch = {
|
234 |
+
"input_ids": source_ids,
|
235 |
+
"attention_mask": source_mask,
|
236 |
+
"labels": y,
|
237 |
+
}
|
238 |
+
return batch
|
239 |
+
|
240 |
+
|
241 |
+
class Seq2SeqDataset(AbstractSeq2SeqDataset):
|
242 |
+
"""A dataset that calls prepare_seq2seq_batch."""
|
243 |
+
|
244 |
+
def __getitem__(self, index) -> Dict[str, str]:
|
245 |
+
|
246 |
+
#print(self.dataset_kwargs['model_t'])
|
247 |
+
# if 't5' in self.dataset_kwargs['model_t']:
|
248 |
+
# self.prefix = 'translate Graph to English: '
|
249 |
+
# print('aac')
|
250 |
+
# exit()
|
251 |
+
|
252 |
+
index = index + 1 # linecache starts at 1
|
253 |
+
source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n")
|
254 |
+
tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
|
255 |
+
assert source_line, f"empty source line for index {index}"
|
256 |
+
assert tgt_line, f"empty tgt line for index {index}"
|
257 |
+
return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}
|
258 |
+
|
259 |
+
def collate_fn(self, batch):
|
260 |
+
"""Call prepare_seq2seq_batch."""
|
261 |
+
batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
|
262 |
+
[x["src_texts"] for x in batch],
|
263 |
+
tgt_texts=[x["tgt_texts"] for x in batch],
|
264 |
+
max_length=self.max_source_length,
|
265 |
+
max_target_length=self.max_target_length,
|
266 |
+
return_tensors="pt",
|
267 |
+
**self.dataset_kwargs,
|
268 |
+
).data
|
269 |
+
#lens = (batch_encoding['attention_mask'] == 1.).sum(dim=1).tolist()
|
270 |
+
|
271 |
+
batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
|
272 |
+
|
273 |
+
return batch_encoding
|
274 |
+
|
275 |
+
|
276 |
+
|
277 |
+
class Seq2SeqDataCollator:
|
278 |
+
def __init__(self, tokenizer, data_args, tpu_num_cores=None):
|
279 |
+
self.tokenizer = tokenizer
|
280 |
+
self.pad_token_id = tokenizer.pad_token_id
|
281 |
+
assert (
|
282 |
+
self.pad_token_id is not None
|
283 |
+
), f"pad_token_id is not defined for ({self.tokenizer.__class__.__name__}), it must be defined."
|
284 |
+
self.data_args = data_args
|
285 |
+
self.tpu_num_cores = tpu_num_cores
|
286 |
+
self.dataset_kwargs = {"add_prefix_space": isinstance(tokenizer, BartTokenizer)}
|
287 |
+
if data_args.src_lang is not None:
|
288 |
+
self.dataset_kwargs["src_lang"] = data_args.src_lang
|
289 |
+
if data_args.tgt_lang is not None:
|
290 |
+
self.dataset_kwargs["tgt_lang"] = data_args.tgt_lang
|
291 |
+
|
292 |
+
def __call__(self, batch) -> Dict[str, torch.Tensor]:
|
293 |
+
if hasattr(self.tokenizer, "prepare_seq2seq_batch"):
|
294 |
+
batch = self._encode(batch)
|
295 |
+
input_ids, attention_mask, labels = (
|
296 |
+
batch["input_ids"],
|
297 |
+
batch["attention_mask"],
|
298 |
+
batch["labels"],
|
299 |
+
)
|
300 |
+
else:
|
301 |
+
input_ids = torch.stack([x["input_ids"] for x in batch])
|
302 |
+
attention_mask = torch.stack([x["attention_mask"] for x in batch])
|
303 |
+
labels = torch.stack([x["labels"] for x in batch])
|
304 |
+
|
305 |
+
labels = trim_batch(labels, self.pad_token_id)
|
306 |
+
input_ids, attention_mask = trim_batch(input_ids, self.pad_token_id, attention_mask=attention_mask)
|
307 |
+
|
308 |
+
if isinstance(self.tokenizer, T5Tokenizer):
|
309 |
+
decoder_input_ids = self._shift_right_t5(labels)
|
310 |
+
else:
|
311 |
+
decoder_input_ids = shift_tokens_right(labels, self.pad_token_id)
|
312 |
+
|
313 |
+
batch = {
|
314 |
+
"input_ids": input_ids,
|
315 |
+
"attention_mask": attention_mask,
|
316 |
+
"decoder_input_ids": decoder_input_ids,
|
317 |
+
"labels": labels,
|
318 |
+
}
|
319 |
+
return batch
|
320 |
+
|
321 |
+
def _shift_right_t5(self, input_ids):
|
322 |
+
# shift inputs to the right
|
323 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
324 |
+
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
325 |
+
shifted_input_ids[..., 0] = self.pad_token_id
|
326 |
+
return shifted_input_ids
|
327 |
+
|
328 |
+
def _encode(self, batch) -> Dict[str, torch.Tensor]:
|
329 |
+
batch_encoding = self.tokenizer.prepare_seq2seq_batch(
|
330 |
+
[x["src_texts"] for x in batch],
|
331 |
+
tgt_texts=[x["tgt_texts"] for x in batch],
|
332 |
+
max_length=self.data_args.max_source_length,
|
333 |
+
max_target_length=self.data_args.max_target_length,
|
334 |
+
padding="max_length" if self.tpu_num_cores is not None else "longest", # TPU hack
|
335 |
+
return_tensors="pt",
|
336 |
+
**self.dataset_kwargs,
|
337 |
+
)
|
338 |
+
return batch_encoding.data
|
339 |
+
|
340 |
+
|
341 |
+
class SortishSampler(Sampler):
|
342 |
+
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
343 |
+
|
344 |
+
def __init__(self, data, batch_size, shuffle=True):
|
345 |
+
self.data, self.bs, self.shuffle = data, batch_size, shuffle
|
346 |
+
|
347 |
+
def __len__(self) -> int:
|
348 |
+
return len(self.data)
|
349 |
+
|
350 |
+
def __iter__(self):
|
351 |
+
return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
|
352 |
+
|
353 |
+
|
354 |
+
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
|
355 |
+
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
|
356 |
+
if not shuffle:
|
357 |
+
return np.argsort(np.array(data) * -1)
|
358 |
+
|
359 |
+
def key_fn(i):
|
360 |
+
return data[i]
|
361 |
+
|
362 |
+
idxs = np.random.permutation(len(data))
|
363 |
+
sz = bs * 50
|
364 |
+
ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
|
365 |
+
sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
|
366 |
+
sz = bs
|
367 |
+
ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
|
368 |
+
max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
|
369 |
+
ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
|
370 |
+
sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=int)  # np.int is removed in recent NumPy
|
371 |
+
sort_idx = np.concatenate((ck_idx[0], sort_idx))
|
372 |
+
return sort_idx
|
373 |
+
|
374 |
+
|
375 |
+
class DistributedSortishSampler(Sampler):
|
376 |
+
"""Copied from torch DistributedSampler"""
|
377 |
+
|
378 |
+
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
|
379 |
+
if num_replicas is None:
|
380 |
+
if not dist.is_available():
|
381 |
+
raise RuntimeError("Requires distributed package to be available")
|
382 |
+
num_replicas = dist.get_world_size()
|
383 |
+
if rank is None:
|
384 |
+
if not dist.is_available():
|
385 |
+
raise RuntimeError("Requires distributed package to be available")
|
386 |
+
rank = dist.get_rank()
|
387 |
+
self.dataset = dataset
|
388 |
+
self.num_replicas = num_replicas
|
389 |
+
self.rank = rank
|
390 |
+
self.epoch = 0
|
391 |
+
if add_extra_examples:
|
392 |
+
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
393 |
+
self.total_size = self.num_samples * self.num_replicas
|
394 |
+
else:
|
395 |
+
self.total_size = len(dataset)
|
396 |
+
self.num_samples = len(self.available_indices)
|
397 |
+
self.batch_size = batch_size
|
398 |
+
self.add_extra_examples = add_extra_examples
|
399 |
+
self.shuffle = shuffle
|
400 |
+
|
401 |
+
def __iter__(self) -> Iterable:
|
402 |
+
g = torch.Generator()
|
403 |
+
g.manual_seed(self.epoch)
|
404 |
+
|
405 |
+
sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
|
406 |
+
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
|
407 |
+
indices = [self.available_indices[i] for i in sortish_indices]
|
408 |
+
assert len(indices) == self.num_samples
|
409 |
+
return iter(indices)
|
410 |
+
|
411 |
+
@cached_property
|
412 |
+
def available_indices(self) -> np.array:
|
413 |
+
indices = list(range(len(self.dataset)))
|
414 |
+
# add extra samples to make it evenly divisible
|
415 |
+
indices += indices[: (self.total_size - len(indices))]
|
416 |
+
assert len(indices) == self.total_size
|
417 |
+
# subsample
|
418 |
+
available_indices = indices[self.rank : self.total_size : self.num_replicas]
|
419 |
+
return available_indices
|
420 |
+
|
421 |
+
def __len__(self):
|
422 |
+
return self.num_samples
|
423 |
+
|
424 |
+
def set_epoch(self, epoch):
|
425 |
+
self.epoch = epoch
|
426 |
+
|
427 |
+
|
428 |
+
logger = getLogger(__name__)
|
429 |
+
|
430 |
+
|
431 |
+
def use_task_specific_params(model, task):
|
432 |
+
"""Update config with summarization specific params."""
|
433 |
+
task_specific_params = model.config.task_specific_params
|
434 |
+
|
435 |
+
if task_specific_params is not None:
|
436 |
+
pars = task_specific_params.get(task, {})
|
437 |
+
logger.info(f"using task specific params for {task}: {pars}")
|
438 |
+
model.config.update(pars)
|
439 |
+
|
440 |
+
|
441 |
+
def pickle_load(path):
|
442 |
+
"""pickle.load(path)"""
|
443 |
+
with open(path, "rb") as f:
|
444 |
+
return pickle.load(f)
|
445 |
+
|
446 |
+
|
447 |
+
def pickle_save(obj, path):
|
448 |
+
"""pickle.dump(obj, path)"""
|
449 |
+
with open(path, "wb") as f:
|
450 |
+
return pickle.dump(obj, f)
|
451 |
+
|
452 |
+
|
453 |
+
def flatten_list(summary_ids: List[List]):
|
454 |
+
return [x for x in itertools.chain.from_iterable(summary_ids)]
|
455 |
+
|
456 |
+
|
457 |
+
def save_json(content, path, indent=4, **json_dump_kwargs):
|
458 |
+
with open(path, "w") as f:
|
459 |
+
json.dump(content, f, indent=indent, **json_dump_kwargs)
|
460 |
+
|
461 |
+
|
462 |
+
def load_json(path):
|
463 |
+
with open(path) as f:
|
464 |
+
return json.load(f)
|
465 |
+
|
466 |
+
|
467 |
+
ROUGE_KEYS = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
|
468 |
+
|
469 |
+
|
470 |
+
def extract_rouge_mid_statistics(dct):
|
471 |
+
new_dict = {}
|
472 |
+
for k1, v1 in dct.items():
|
473 |
+
mid = v1.mid
|
474 |
+
new_dict[k1] = {stat: round(getattr(mid, stat), 4) for stat in ["precision", "recall", "fmeasure"]}
|
475 |
+
return new_dict
|
476 |
+
|
477 |
+
|
478 |
+
def calculate_rouge(
|
479 |
+
pred_lns: List[str],
|
480 |
+
tgt_lns: List[str],
|
481 |
+
use_stemmer=True,
|
482 |
+
rouge_keys=ROUGE_KEYS,
|
483 |
+
return_precision_and_recall=False,
|
484 |
+
bootstrap_aggregation=True,
|
485 |
+
newline_sep=True,
|
486 |
+
) -> Dict:
|
487 |
+
"""Calculate rouge using rouge_scorer package.
|
488 |
+
|
489 |
+
Args:
|
490 |
+
pred_lns: list of summaries generated by model
|
491 |
+
tgt_lns: list of groundtruth summaries (e.g. contents of val.target)
|
492 |
+
use_stemmer: Bool indicating whether Porter stemmer should be used to
|
493 |
+
strip word suffixes to improve matching.
|
494 |
+
rouge_keys: which metrics to compute, defaults to rouge1, rouge2, rougeL, rougeLsum
|
495 |
+
return_precision_and_recall: (False) whether to also return precision and recall.
|
496 |
+
bootstrap_aggregation: whether to do the typical bootstrap resampling of scores. Defaults to True, if False
|
497 |
+
this function returns a collections.defaultdict[metric: list of values for each observation for each subscore]``
|
498 |
+
newline_sep:(default=True) whether to add newline between sentences. This is essential for calculation rougeL
|
499 |
+
on multi sentence summaries (CNN/DM dataset).
|
500 |
+
|
501 |
+
Returns:
|
502 |
+
Dict[score: value] if aggregate else defaultdict(list) keyed by rouge_keys
|
503 |
+
|
504 |
+
"""
|
505 |
+
scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=use_stemmer)
|
506 |
+
aggregator = scoring.BootstrapAggregator()
|
507 |
+
for pred, tgt in zip(tgt_lns, pred_lns):
|
508 |
+
# rougeLsum expects "\n" separated sentences within a summary
|
509 |
+
if newline_sep:
|
510 |
+
pred = add_newline_to_end_of_each_sentence(pred)
|
511 |
+
tgt = add_newline_to_end_of_each_sentence(tgt)
|
512 |
+
scores = scorer.score(pred, tgt)
|
513 |
+
aggregator.add_scores(scores)
|
514 |
+
|
515 |
+
if bootstrap_aggregation:
|
516 |
+
result = aggregator.aggregate()
|
517 |
+
if return_precision_and_recall:
|
518 |
+
return extract_rouge_mid_statistics(result) # here we return dict
|
519 |
+
else:
|
520 |
+
return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
|
521 |
+
|
522 |
+
else:
|
523 |
+
return aggregator._scores # here we return defaultdict(list)
|
524 |
+
|
525 |
+
|
526 |
+
# Utilities for freezing parameters and checking whether they are frozen
|
527 |
+
|
528 |
+
|
529 |
+
def freeze_params(model: nn.Module):
|
530 |
+
"""Set requires_grad=False for each of model.parameters()"""
|
531 |
+
for par in model.parameters():
|
532 |
+
par.requires_grad = False
|
533 |
+
|
534 |
+
|
535 |
+
def freeze_embeds(model):
|
536 |
+
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
537 |
+
model_type = model.config.model_type
|
538 |
+
|
539 |
+
if model_type == "t5":
|
540 |
+
freeze_params(model.shared)
|
541 |
+
for d in [model.encoder, model.decoder]:
|
542 |
+
freeze_params(d.embed_tokens)
|
543 |
+
elif model_type == "fsmt":
|
544 |
+
for d in [model.model.encoder, model.model.decoder]:
|
545 |
+
freeze_params(d.embed_positions)
|
546 |
+
freeze_params(d.embed_tokens)
|
547 |
+
else:
|
548 |
+
freeze_params(model.model.shared)
|
549 |
+
for d in [model.model.encoder, model.model.decoder]:
|
550 |
+
freeze_params(d.embed_positions)
|
551 |
+
freeze_params(d.embed_tokens)
|
552 |
+
|
553 |
+
|
554 |
+
def grad_status(model: nn.Module) -> Iterable:
|
555 |
+
return (par.requires_grad for par in model.parameters())
|
556 |
+
|
557 |
+
|
558 |
+
def any_requires_grad(model: nn.Module) -> bool:
|
559 |
+
return any(grad_status(model))
|
560 |
+
|
561 |
+
|
562 |
+
def assert_all_frozen(model):
|
563 |
+
model_grads: List[bool] = list(grad_status(model))
|
564 |
+
n_require_grad = sum(lmap(int, model_grads))
|
565 |
+
npars = len(model_grads)
|
566 |
+
assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
|
567 |
+
|
568 |
+
|
569 |
+
def assert_not_all_frozen(model):
|
570 |
+
model_grads: List[bool] = list(grad_status(model))
|
571 |
+
npars = len(model_grads)
|
572 |
+
assert any(model_grads), f"none of {npars} weights require grad"
|
573 |
+
|
574 |
+
|
575 |
+
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
|
576 |
+
"""
|
577 |
+
Parse an argv list of unspecified command line args to a dict.
|
578 |
+
Assumes all values are either numeric or boolean in the form of true/false.
|
579 |
+
"""
|
580 |
+
result = {}
|
581 |
+
assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
|
582 |
+
num_pairs = len(unparsed_args) // 2
|
583 |
+
for pair_num in range(num_pairs):
|
584 |
+
i = 2 * pair_num
|
585 |
+
assert unparsed_args[i].startswith("--")
|
586 |
+
if unparsed_args[i + 1].lower() == "true":
|
587 |
+
value = True
|
588 |
+
elif unparsed_args[i + 1].lower() == "false":
|
589 |
+
value = False
|
590 |
+
else:
|
591 |
+
try:
|
592 |
+
value = int(unparsed_args[i + 1])
|
593 |
+
except ValueError:
|
594 |
+
value = float(unparsed_args[i + 1]) # this can raise another informative ValueError
|
595 |
+
|
596 |
+
result[unparsed_args[i][2:]] = value
|
597 |
+
return result
|
598 |
+
|
599 |
+
|
600 |
+
def write_txt_file(ordered_tgt, path):
|
601 |
+
f = Path(path).open("w")
|
602 |
+
for ln in ordered_tgt:
|
603 |
+
f.write(ln + "\n")
|
604 |
+
f.flush()
|
605 |
+
|
606 |
+
|
607 |
+
def chunks(lst, n):
|
608 |
+
"""Yield successive n-sized chunks from lst."""
|
609 |
+
for i in range(0, len(lst), n):
|
610 |
+
yield lst[i : i + n]
|
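Most of this file appears to be adapted from the Hugging Face seq2seq example utilities; a few of the helpers can be exercised on their own, e.g. (illustrative values only):

from utils.utils_verbalisation_module import calculate_bleu, chunks, lmap

preds = ['the cat sat on the mat']
refs = ['the cat is on the mat']
print(calculate_bleu(preds, refs))        # {'sacrebleu': <corpus BLEU, rounded to 4 decimals>}
print(list(chunks(list(range(5)), 2)))    # [[0, 1], [2, 3], [4]]
print(lmap(str.upper, ['a', 'b']))        # ['A', 'B']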
utils/verbalisation_module.py
ADDED
@@ -0,0 +1,300 @@
1 |
+
from utils.finetune import Graph2TextModule
|
2 |
+
from typing import Dict, List, Tuple, Union, Optional
|
3 |
+
import torch
|
4 |
+
import re
|
5 |
+
|
6 |
+
if torch.cuda.is_available():
|
7 |
+
DEVICE = 'cuda'
|
8 |
+
else:
|
9 |
+
DEVICE = 'cpu'
|
10 |
+
print('CUDA NOT AVAILABLE')
|
11 |
+
|
12 |
+
CHECKPOINT = 'base/t5-base_13881_val_avg_bleu=68.1000-step_count=5.ckpt'
|
13 |
+
MAX_LENGTH = 384
|
14 |
+
SEED = 42
|
15 |
+
|
16 |
+
|
17 |
+
class VerbModule():
|
18 |
+
|
19 |
+
def __init__(self, override_args: Dict[str, str] = None):
|
20 |
+
# Model
|
21 |
+
if not override_args:
|
22 |
+
override_args = {}
|
23 |
+
self.g2t_module = Graph2TextModule.load_from_checkpoint(CHECKPOINT, strict=False, **override_args)
|
24 |
+
self.tokenizer = self.g2t_module.tokenizer
|
25 |
+
# Unk replacer
|
26 |
+
self.vocab = self.tokenizer.get_vocab()
|
27 |
+
self.convert_some_japanese_characters = True
|
28 |
+
self.unk_char_replace_sliding_window_size = 2
|
29 |
+
self.unknowns = []
|
30 |
+
|
31 |
+
def __generate_verbalisations_from_inputs(self, inputs: Union[str, List[str]]):
|
32 |
+
try:
|
33 |
+
inputs_encoding = self.tokenizer.prepare_seq2seq_batch(
|
34 |
+
inputs, truncation=True, max_length=MAX_LENGTH, return_tensors='pt'
|
35 |
+
)
|
36 |
+
inputs_encoding = {k: v.to(DEVICE) for k, v in inputs_encoding.items()}
|
37 |
+
|
38 |
+
self.g2t_module.model.eval()
|
39 |
+
with torch.no_grad():
|
40 |
+
gen_output = self.g2t_module.model.generate(
|
41 |
+
inputs_encoding['input_ids'],
|
42 |
+
attention_mask=inputs_encoding['attention_mask'],
|
43 |
+
use_cache=True,
|
44 |
+
decoder_start_token_id = self.g2t_module.decoder_start_token_id,
|
45 |
+
num_beams= self.g2t_module.eval_beams,
|
46 |
+
max_length= self.g2t_module.eval_max_length,
|
47 |
+
length_penalty=1.0
|
48 |
+
)
|
49 |
+
except Exception:
|
50 |
+
print(inputs)
|
51 |
+
raise
|
52 |
+
|
53 |
+
return gen_output
|
54 |
+
|
55 |
+
'''
|
56 |
+
We create this function as an alteration from [this one](https://github.com/huggingface/transformers/blob/198c335d219a5eb4d3f124fdd1ce1a9cd9f78a9b/src/transformers/tokenization_utils_fast.py#L537), mainly because the official 'tokenizer.decode' treats all special tokens the same, while we want to drop all special tokens from the decoded sentence EXCEPT for the <unk> token, which we will replace later on.
|
57 |
+
'''
|
58 |
+
def __decode_ids_to_string_custom(
|
59 |
+
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
|
60 |
+
) -> str:
|
61 |
+
filtered_tokens = self.tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)
|
62 |
+
# Do not remove special tokens yet
|
63 |
+
|
64 |
+
# To avoid mixing byte-level and unicode for byte-level BPT
|
65 |
+
# we need to build string separatly for added tokens and byte-level tokens
|
66 |
+
# cf. https://github.com/huggingface/transformers/issues/1133
|
67 |
+
sub_texts = []
|
68 |
+
current_sub_text = []
|
69 |
+
for token in filtered_tokens:
|
70 |
+
if skip_special_tokens and\
|
71 |
+
token != self.tokenizer.unk_token and\
|
72 |
+
token in self.tokenizer.all_special_tokens:
|
73 |
+
|
74 |
+
continue
|
75 |
+
else:
|
76 |
+
current_sub_text.append(token)
|
77 |
+
if current_sub_text:
|
78 |
+
sub_texts.append(self.tokenizer.convert_tokens_to_string(current_sub_text))
|
79 |
+
text = " ".join(sub_texts)
|
80 |
+
|
81 |
+
if clean_up_tokenization_spaces:
|
82 |
+
clean_text = self.tokenizer.clean_up_tokenization(text)
|
83 |
+
return clean_text
|
84 |
+
else:
|
85 |
+
return text
|
86 |
+
|
87 |
+
def __decode_sentences(self, encoded_sentences: Union[str, List[str]]):
|
88 |
+
if type(encoded_sentences) == str:
|
89 |
+
encoded_sentences = [encoded_sentences]
|
90 |
+
decoded_sentences = [self.__decode_ids_to_string_custom(i, skip_special_tokens=True) for i in encoded_sentences]
|
91 |
+
return decoded_sentences
|
92 |
+
|
93 |
+
def verbalise_sentence(self, inputs: Union[str, List[str]]):
|
94 |
+
if type(inputs) == str:
|
95 |
+
inputs = [inputs]
|
96 |
+
|
97 |
+
gen_output = self.__generate_verbalisations_from_inputs(inputs)
|
98 |
+
|
99 |
+
decoded_sentences = self.__decode_sentences(gen_output)
|
100 |
+
|
101 |
+
if len(decoded_sentences) == 1:
|
102 |
+
return decoded_sentences[0]
|
103 |
+
else:
|
104 |
+
return decoded_sentences
|
105 |
+
|
106 |
+
def verbalise_triples(self, input_triples: Union[Dict[str, str], List[Dict[str, str]], List[List[Dict[str, str]]]]):
|
107 |
+
if type(input_triples) == dict:
|
108 |
+
input_triples = [input_triples]
|
109 |
+
|
110 |
+
verbalisation_inputs = []
|
111 |
+
for triple in input_triples:
|
112 |
+
if type(triple) == dict:
|
113 |
+
assert 'subject' in triple
|
114 |
+
assert 'predicate' in triple
|
115 |
+
assert 'object' in triple
|
116 |
+
verbalisation_inputs.append(
|
117 |
+
f'translate Graph to English: <H> {triple["subject"]} <R> {triple["predicate"]} <T> {triple["object"]}'
|
118 |
+
)
|
119 |
+
elif type(triple) == list:
|
120 |
+
input_sentence = ['translate Graph to English:']
|
121 |
+
for subtriple in triple:
|
122 |
+
assert 'subject' in subtriple
|
123 |
+
assert 'predicate' in subtriple
|
124 |
+
assert 'object' in subtriple
|
125 |
+
input_sentence.append(f'<H> {subtriple["subject"]}')
|
126 |
+
input_sentence.append(f'<R> {subtriple["predicate"]}')
|
127 |
+
input_sentence.append(f'<T> {subtriple["object"]}')
|
128 |
+
verbalisation_inputs.append(
|
129 |
+
' '.join(input_sentence)
|
130 |
+
)
|
131 |
+
|
132 |
+
return self.verbalise_sentence(verbalisation_inputs)
|
133 |
+
|
134 |
+
def verbalise(self, input: Union[str, List, Dict]):
|
135 |
+
try:
|
136 |
+
if (type(input) == str) or (type(input) == list and type(input[0]) == str):
|
137 |
+
return self.verbalise_sentence(input)
|
138 |
+
elif (type(input) == dict) or (type(input) == list and type(input[0]) == dict):
|
139 |
+
return self.verbalise_triples(input)
|
140 |
+
else:
|
141 |
+
return self.verbalise_triples(input)
|
142 |
+
except Exception:
|
143 |
+
print(f'ERROR VERBALISING {input}')
|
144 |
+
raise
|
145 |
+
|
146 |
+
def add_label_to_unk_replacer(self, label: str):
|
147 |
+
N = self.unk_char_replace_sliding_window_size
|
148 |
+
self.unknowns.append({})
|
149 |
+
|
150 |
+
# Some pre-processing of labels to normalise some characters
|
151 |
+
if self.convert_some_japanese_characters:
|
152 |
+
label = label.replace('(','(')
|
153 |
+
label = label.replace(')',')')
|
154 |
+
label = label.replace('〈','<')
|
155 |
+
label = label.replace('／','/')  # fullwidth solidus to ASCII slash
|
156 |
+
label = label.replace('〉','>')
|
157 |
+
|
158 |
+
label_encoded = self.tokenizer.encode(label)
|
159 |
+
label_tokens = self.tokenizer.convert_ids_to_tokens(label_encoded)
|
160 |
+
|
161 |
+
# Here, we also remove </s> (eos) and <pad> tokens in the replacing key, because:
|
162 |
+
# 1) When the whole label is all unk:
|
163 |
+
# label_token_to_string would be '<unk></s>', meaning the replacing key (which is the same) only replaces
|
164 |
+
# the <unk> if it appears at the end of the sentence, which is not the desired effect.
|
165 |
+
# But since this means ANY <unk> will be replaced by this, it would be good to only replace keys that are <unk>
|
166 |
+
# on the last replacing pass.
|
167 |
+
# 2) On other cases, then the unk is in the label but not in its entirety, like in the start/end, it might
|
168 |
+
# involve the starting <pad> token or the ending <eos> token on the replacing key, again forcing the replacement
|
169 |
+
# to only happen if the label appears in the end of the sentence.
|
170 |
+
label_tokens = [t for t in label_tokens if t not in [
|
171 |
+
self.tokenizer.eos_token, self.tokenizer.pad_token
|
172 |
+
]]
|
173 |
+
|
174 |
+
label_token_to_string = self.tokenizer.convert_tokens_to_string(label_tokens)
|
175 |
+
unk_token_to_string = self.tokenizer.convert_tokens_to_string([self.tokenizer.unk_token])
|
176 |
+
|
177 |
+
#print(label_encoded,label_tokens,label_token_to_string)
|
178 |
+
|
179 |
+
match_unks_in_label = re.findall('(?:(?: )*<unk>(?: )*)+', label_token_to_string)
|
180 |
+
if len(match_unks_in_label) > 0:
|
181 |
+
# If the whole label is made of UNK
|
182 |
+
if (match_unks_in_label[0]) == label_token_to_string:
|
183 |
+
#print('Label is all unks')
|
184 |
+
self.unknowns[-1][label_token_to_string.strip()] = label
|
185 |
+
# Else, there should be non-UNK characters in the label
|
186 |
+
else:
|
187 |
+
#print('Label is NOT all unks')
|
188 |
+
# Analyse the label with a sliding window of size N (N before, N ahead)
|
189 |
+
for idx, token in enumerate(label_tokens):
|
190 |
+
idx_before = max(0,idx-N)
|
191 |
+
idx_ahead = min(len(label_tokens), idx+N+1)
|
192 |
+
|
193 |
+
|
194 |
+
# Found a UNK
|
195 |
+
if token == self.tokenizer.unk_token:
|
196 |
+
|
197 |
+
# In case multiple UNK, exclude UNKs seen after this one, expand window to other side if possible
|
198 |
+
if len(match_unks_in_label) > 1:
|
199 |
+
#print(idx)
|
200 |
+
#print(label_tokens)
|
201 |
+
#print(label_tokens[idx_before:idx_ahead])
|
202 |
+
#print('HERE!')
|
203 |
+
# Reduce on the right, expanding on the left
|
204 |
+
while self.tokenizer.unk_token in label_tokens[idx+1:idx_ahead]:
|
205 |
+
idx_before = max(0,idx_before-1)
|
206 |
+
idx_ahead = min(idx+2, idx_ahead-1)
|
207 |
+
#print(label_tokens[idx_before:idx_ahead])
|
208 |
+
# Now just reduce on the left
|
209 |
+
while self.tokenizer.unk_token in label_tokens[idx_before:idx]:
|
210 |
+
idx_before = min(idx-1,idx_before+2)
|
211 |
+
#print(label_tokens[idx_before:idx_ahead])
|
212 |
+
|
213 |
+
span = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx_ahead])
|
214 |
+
# First token of the label is UNK
|
215 |
+
if idx == 1 and label_tokens[0] == '▁':
|
216 |
+
#print('Label begins with unks')
|
217 |
+
to_replace = '^' + re.escape(span).replace(
|
218 |
+
re.escape(unk_token_to_string),
|
219 |
+
'.+?'
|
220 |
+
)
|
221 |
+
|
222 |
+
replaced_span = re.search(
|
223 |
+
to_replace,
|
224 |
+
label
|
225 |
+
)[0]
|
226 |
+
self.unknowns[-1][span.strip()] = replaced_span
|
227 |
+
# Last token of the label is UNK
|
228 |
+
elif idx == len(label_tokens)-2 and label_tokens[-1] == self.tokenizer.eos_token:
|
229 |
+
#print('Label ends with unks')
|
230 |
+
pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx])
|
231 |
+
pre_idx_unk_counts = pre_idx.count(unk_token_to_string)
|
232 |
+
to_replace = re.escape(span).replace(
|
233 |
+
re.escape(unk_token_to_string),
|
234 |
+
f'[^{re.escape(pre_idx)}]+?'
|
235 |
+
) + '$'
|
236 |
+
|
237 |
+
if pre_idx.strip() == '':
|
238 |
+
to_replace = to_replace.replace('[^]', '(?<=\s)[^a-zA-Z0-9]')
|
239 |
+
|
240 |
+
replaced_span = re.search(
|
241 |
+
to_replace,
|
242 |
+
label
|
243 |
+
)[0]
|
244 |
+
self.unknowns[-1][span.strip()] = replaced_span
|
245 |
+
|
246 |
+
# A token in-between the label is UNK
|
247 |
+
else:
|
248 |
+
#print('Label has unks in the middle')
|
249 |
+
pre_idx = self.tokenizer.convert_tokens_to_string(label_tokens[idx_before:idx])
|
250 |
+
|
251 |
+
to_replace = re.escape(span).replace(
|
252 |
+
re.escape(unk_token_to_string),
|
253 |
+
f'[^{re.escape(pre_idx)}]+?'
|
254 |
+
)
|
255 |
+
#If there is nothing behind the ??, because it is in the middle but the previous token is also
|
256 |
+
#a ??, then we would end up with to_replace beginning with [^], which we can't have
|
257 |
+
if pre_idx.strip() == '':
|
258 |
+
to_replace = to_replace.replace('[^]', '(?<=\s)[^a-zA-Z0-9]')
|
259 |
+
|
260 |
+
replaced_span = re.search(
|
261 |
+
to_replace,
|
262 |
+
label
|
263 |
+
)
|
264 |
+
|
265 |
+
if replaced_span:
|
266 |
+
span = re.sub(r'\s([?.!",](?:\s|$))', r'\1', span.strip())
|
267 |
+
self.unknowns[-1][span] = replaced_span[0]
|
268 |
+
|
269 |
+
def replace_unks_on_sentence(self, sentence: str, loop_n : int = 3, empty_after : bool = False):
|
270 |
+
# Loop through in case the labels are repeated, maximum of three times
|
271 |
+
while '<unk>' in sentence and loop_n > 0:
|
272 |
+
loop_n -= 1
|
273 |
+
for unknowns in self.unknowns:
|
274 |
+
for k,v in unknowns.items():
|
275 |
+
# Leave to replace all-unk labels at the last pass
|
276 |
+
if k == '<unk>' and loop_n > 0:
|
277 |
+
continue
|
278 |
+
# In case it is because the first letter of the sentence has been uppercased
|
279 |
+
if not k in sentence and k[0] == k[0].lower() and k[0].upper() == sentence[0]:
|
280 |
+
k = k[0].upper() + k[1:]
|
281 |
+
v = v[0].upper() + v[1:]
|
282 |
+
# In case it is because a double space is found where it should not be
|
283 |
+
elif not k in sentence and len(re.findall(r'\s{2,}',k))>0:
|
284 |
+
k = re.sub(r'\s+', ' ', k)
|
285 |
+
#print(k,'/',v,'/',sentence)
|
286 |
+
sentence = sentence.replace(k.strip(),v.strip(),1)
|
287 |
+
#sentence = re.sub(k, v, sentence)
|
288 |
+
# Removing final doublespaces
|
289 |
+
sentence = re.sub(r'\s+', ' ', sentence).strip()
|
290 |
+
# Removing spaces before punctuation
|
291 |
+
sentence = re.sub(r'\s([?.!",](?:\s|$))', r'\1', sentence)
|
292 |
+
if empty_after:
|
293 |
+
self.unknowns = []
|
294 |
+
return sentence
|
295 |
+
|
296 |
+
if __name__ == '__main__':
|
297 |
+
|
298 |
+
verb_module = VerbModule()
|
299 |
+
verbs = verb_module.verbalise('translate Graph to English: <H> World Trade Center <R> height <T> 200 meter <H> World Trade Center <R> is a <T> tower')
|
300 |
+
print(verbs)
|
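Besides the raw '<H> ... <R> ... <T> ...' prompt shown in the __main__ block, verbalise_triples accepts plain dictionaries; a small sketch (same checkpoint assumption as above):

verb_module = VerbModule()
# a flat list of triple dicts yields one sentence per triple
sentences = verb_module.verbalise_triples([
    {'subject': 'World Trade Center', 'predicate': 'height', 'object': '200 meter'},
    {'subject': 'World Trade Center', 'predicate': 'is a', 'object': 'tower'},
])
# nesting the triples in an inner list yields a single sentence covering both
sentence = verb_module.verbalise_triples([[
    {'subject': 'World Trade Center', 'predicate': 'height', 'object': '200 meter'},
    {'subject': 'World Trade Center', 'predicate': 'is a', 'object': 'tower'},
]])
print(sentences, sentence)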
utils/wikidata_utils.py
ADDED
@@ -0,0 +1,173 @@
1 |
+
import json
|
2 |
+
import random
|
3 |
+
import uuid
|
4 |
+
import numpy as np
|
5 |
+
import time
|
6 |
+
import requests
|
7 |
+
import traceback
|
8 |
+
import pdb
|
9 |
+
import math
|
10 |
+
import ast
|
11 |
+
import pandas as pd
|
12 |
+
import pickle
|
13 |
+
from qwikidata.linked_data_interface import get_entity_dict_from_api
|
14 |
+
from qwikidata.sparql import return_sparql_query_results
|
15 |
+
|
16 |
+
from urllib3.exceptions import MaxRetryError, ConnectionError
|
17 |
+
from qwikidata.linked_data_interface import LdiResponseNotOk
|
18 |
+
|
19 |
+
import hashlib
|
20 |
+
|
21 |
+
class CachedWikidataAPI():
|
22 |
+
|
23 |
+
def __init__(self, cache_path = 'entity_cache.p', save_every_x_queries=1):
|
24 |
+
self.save_every_x_queries = save_every_x_queries
|
25 |
+
self.x_queries_passed = 0
|
26 |
+
self.languages = ['en','fr','es','pt','pt-br','it','de']
|
27 |
+
self.cache_path = cache_path
|
28 |
+
try:
|
29 |
+
with open(self.cache_path,'rb') as f:
|
30 |
+
self.entity_cache = pickle.load(f)
|
31 |
+
except FileNotFoundError:
|
32 |
+
self.entity_cache = {}
|
33 |
+
|
34 |
+
def get_unique_id_from_str(self, my_str):
|
35 |
+
return hashlib.md5(str.encode(my_str)).hexdigest()
|
36 |
+
|
37 |
+
def save_entity_cache(self, force=False):
|
38 |
+
if force:
|
39 |
+
self.x_queries_passed = self.save_every_x_queries
|
40 |
+
self.x_queries_passed = self.x_queries_passed+1
|
41 |
+
if self.x_queries_passed >= self.save_every_x_queries:
|
42 |
+
with open(self.cache_path,'wb') as f:
|
43 |
+
pickle.dump(self.entity_cache,f)
|
44 |
+
self.x_queries_passed = 0
|
45 |
+
|
46 |
+
def get_entity(self, item_id):
|
47 |
+
if item_id in self.entity_cache:
|
48 |
+
return self.entity_cache[item_id]
|
49 |
+
while True:
|
50 |
+
try:
|
51 |
+
entity = get_entity_dict_from_api(item_id)
|
52 |
+
self.entity_cache[item_id] = entity
|
53 |
+
self.save_entity_cache()
|
54 |
+
return entity
|
55 |
+
except (ConnectionError, MaxRetryError) as e:
|
56 |
+
#traceback.print_exc()
|
57 |
+
time.sleep(1)
|
58 |
+
continue
|
59 |
+
except LdiResponseNotOk:
|
60 |
+
#traceback.print_exc()
|
61 |
+
self.entity_cache[item_id] = 'deleted'
|
62 |
+
self.save_entity_cache()
|
63 |
+
return 'deleted'
|
64 |
+
|
65 |
+
def get_label(self, item, non_language_set=False):
|
66 |
+
if type(item) == str:
|
67 |
+
entity = self.get_entity(item)
|
68 |
+
if entity == 'deleted':
|
69 |
+
return (entity, 'none')
|
70 |
+
labels = entity['labels' if 'labels' in entity else 'lemmas']
|
71 |
+
elif type(item) == dict:
|
72 |
+
if 'labels' in item:
|
73 |
+
labels = item['labels']
|
74 |
+
elif 'lemmas' in item:
|
75 |
+
labels = item['lemmas']
|
76 |
+
for l in self.languages:
|
77 |
+
if l in labels:
|
78 |
+
return (labels[l]['value'], l)
|
79 |
+
if non_language_set:
|
80 |
+
all_labels = list(labels.keys())
|
81 |
+
if len(all_labels)>0:
|
82 |
+
return (labels[all_labels[0]]['value'], all_labels[0])
|
83 |
+
return ('no-label', 'none')
|
84 |
+
|
85 |
+
def get_desc(self, item, non_language_set=False):
|
86 |
+
if type(item) == str:
|
87 |
+
entity = self.get_entity(item)
|
88 |
+
if entity == 'deleted':
|
89 |
+
return (entity, 'none')
|
90 |
+
descriptions = entity['descriptions']
|
91 |
+
elif type(item) == dict:
|
92 |
+
if 'descriptions' in item:
|
93 |
+
descriptions = item['descriptions']
|
94 |
+
for l in self.languages:
|
95 |
+
if l in descriptions:
|
96 |
+
return (descriptions[l]['value'], l)
|
97 |
+
if non_language_set:
|
98 |
+
all_descriptions = list(descriptions.keys())
|
99 |
+
if len(all_descriptions)>0:
|
100 |
+
return (descriptions[all_descriptions[0]]['value'], all_descriptions[0])
|
101 |
+
return ('no-desc', 'none')
|
102 |
+
|
103 |
+
def get_alias(self, item, non_language_set=False):
|
104 |
+
if type(item) == str:
|
105 |
+
entity = self.get_entity(item)
|
106 |
+
if entity == 'deleted':
|
107 |
+
return ([entity], 'none')
|
108 |
+
aliases = entity['aliases']
|
109 |
+
elif type(item) == dict:
|
110 |
+
if 'aliases' in item:
|
111 |
+
aliases = item['aliases']
|
112 |
+
for l in self.languages:
|
113 |
+
if l in aliases:
|
114 |
+
return ([alias['value'] for alias in aliases[l]], l)
|
115 |
+
if non_language_set:
|
116 |
+
all_aliases = list(aliases.keys())
|
117 |
+
if len(all_aliases)>0:
|
118 |
+
# aliases[lang] is a list of {'language', 'value'} dicts, so collect every value via the list comprehension below
|
119 |
+
return ([alias['value'] for alias in aliases[all_aliases[0]]], all_aliases[0])
|
120 |
+
return ('no-alias', 'none')
|
121 |
+
|
122 |
+
def get_datatype(self, item):
|
123 |
+
try:
|
124 |
+
if type(item) == str:
|
125 |
+
entity = self.get_entity(item)
|
126 |
+
if entity == 'deleted':
|
127 |
+
return entity
|
128 |
+
datatype = entity['datatype']
|
129 |
+
elif type(item) == dict:
|
130 |
+
datatype = item['datatype']
|
131 |
+
return datatype
|
132 |
+
except KeyError:
|
133 |
+
return 'none'
|
134 |
+
|
135 |
+
def get_claim_values_of(self, item, property_id):
|
136 |
+
if type(item) == str:
|
137 |
+
entity = self.get_entity(item)
|
138 |
+
if entity == 'deleted':
|
139 |
+
return entity
|
140 |
+
claims = entity['claims']
|
141 |
+
elif type(item) == dict:
|
142 |
+
claims = item['claims']
|
143 |
+
if property_id in claims:
|
144 |
+
instance_of_claims = claims[property_id]
|
145 |
+
return [i['mainsnak']['datavalue']['value']['id'] for i in instance_of_claims]
|
146 |
+
else:
|
147 |
+
return []
|
148 |
+
|
149 |
+
def query_sparql_endpoint(self, sparql_query):
|
150 |
+
sparql_query_id = self.get_unique_id_from_str(sparql_query)
|
151 |
+
if sparql_query_id in self.entity_cache:
|
152 |
+
return self.entity_cache[sparql_query_id]
|
153 |
+
else:
|
154 |
+
wikidata_sparql_url = 'https://query.wikidata.org/sparql'
|
155 |
+
try:
|
156 |
+
while True:
|
157 |
+
res = requests.get(wikidata_sparql_url, params={"query": sparql_query, "format": "json"})
|
158 |
+
if res.status_code in (429,504):
|
159 |
+
time.sleep(1)
|
160 |
+
continue
|
161 |
+
elif res.status_code == 200:
|
162 |
+
res = res.json()
|
163 |
+
self.entity_cache[sparql_query_id] = res
|
164 |
+
self.save_entity_cache()
|
165 |
+
return res
|
166 |
+
else:
|
167 |
+
print(res.status_code)
|
168 |
+
raise Exception
|
169 |
+
except json.JSONDecodeError as e:
|
170 |
+
#pdb.set_trace()
|
171 |
+
print(res, res.__dict__)
|
172 |
+
raise e
|
173 |
+
|
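A brief usage sketch for the cached Wikidata client above (the cache path and entity ID are only examples):

from utils.wikidata_utils import CachedWikidataAPI

wd = CachedWikidataAPI(cache_path='entity_cache.p')
label, lang = wd.get_label('Q42')    # e.g. ('Douglas Adams', 'en'); fetched once, then served from the pickle cache
desc, desc_lang = wd.get_desc('Q42')
res = wd.query_sparql_endpoint('SELECT ?p ?o WHERE { wd:Q42 ?p ?o } LIMIT 5')
print(label, desc, len(res['results']['bindings']))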