# NOTE: page-extraction residue preserved as comments so the module parses:
# HaloMaster's picture
# add fengshen
# 50f0fbb
import collections
import numpy as np
# Lightweight record of one masking decision: `index` is the position (or
# list of positions for a span) and `label` is the original token id(s).
MaskedLmInstance = collections.namedtuple("MaskedLmInstance", "index label")
def is_start_piece(piece):
    """Return True when `piece` begins a word in BERT WordPiece notation.

    WordPiece leaves the first piece of a word unmarked and prefixes every
    following piece with "##", so a piece starts a word exactly when it
    does not carry that prefix.
    """
    is_continuation = piece.startswith("##")
    return not is_continuation
def _wordpiece_candidates(tokens, vocab_id_to_token_dict, cls_id, sep_id,
                          do_whole_word_mask):
    """Group token positions into words using the "##" continuation marker."""
    cand_indexes = []
    # 1 marks a position that starts a word (or is CLS/SEP), so that
    # on-the-fly whole word masking is possible downstream.
    token_boundary = [0] * len(tokens)
    for i, token in enumerate(tokens):
        if token == cls_id or token == sep_id:
            token_boundary[i] = 1
            continue
        starts_word = is_start_piece(vocab_id_to_token_dict[token])
        # Whole word masking only changes how positions are grouped; each
        # WordPiece is still predicted independently.
        if do_whole_word_mask and cand_indexes and not starts_word:
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
            if starts_word:
                token_boundary[i] = 1
    return cand_indexes, token_boundary


def _segmented_candidates(tokens, vocab_id_to_token_dict, cls_id, sep_id,
                          zh_tokenizer):
    """Group token positions into words using an external word segmenter."""
    cand_indexes = []
    token_boundary = [0] * len(tokens)

    # Rebuild the raw text without CLS/SEP and segment it.
    raw_tokens = [vocab_id_to_token_dict[t] for t in tokens
                  if t != cls_id and t != sep_id]
    word_list = set(zh_tokenizer(''.join(raw_tokens), HMM=True))

    # For every first character, remember the length of the longest
    # segmented word starting with it; this bounds the look-ahead below.
    word_length_dict = {}
    for w in word_list:
        if not w:
            continue
        if len(w) > word_length_dict.get(w[0], 0):
            word_length_dict[w[0]] = len(w)

    i = 0
    while i < len(tokens):
        token_id = tokens[i]
        token = vocab_id_to_token_dict[token_id]
        if len(token) == 0 or token_id == cls_id or token_id == sep_id:
            token_boundary[i] = 1
            i += 1
            continue
        word_max_length = word_length_dict.get(token[0], 1)
        word_end = i + 1
        # Backward compatibility with "##" pieces: trailing "##" tokens are
        # glued onto the current token and the segmenter is not consulted.
        old_style = False
        while (word_end < len(tokens) and
               vocab_id_to_token_dict[tokens[word_end]].startswith('##')):
            old_style = True
            word_end += 1
        if not old_style:
            # Greedily extend the window and keep the longest prefix that
            # the segmenter produced as a word.
            word = ''
            j = 0
            while j < word_max_length and i + j < len(tokens):
                word += vocab_id_to_token_dict[tokens[i + j]]
                j += 1
                if word in word_list:
                    word_end = i + j
        cand_indexes.append(list(range(i, word_end)))
        token_boundary[i] = 1
        i = word_end
    return cand_indexes, token_boundary


def _flatten_ngram(word_groups):
    """Concatenate the per-word position lists of one n-gram."""
    return [position for word in word_groups for position in word]


def _choose_ngram_size(np_rng, ngrams, pvals, num_available):
    """Draw an n-gram size from `pvals`, renormalised over the sizes available."""
    probs = pvals[:num_available]
    return np_rng.choice(ngrams[:num_available],
                         p=probs / probs.sum(keepdims=True))


def _fit_span(cand_index_set, n, num_taken, num_to_predict):
    """Return positions of the n-gram of at most `n` words that fits the budget.

    Shorter n-grams are tried while the span would exceed `num_to_predict`.
    The unigram is returned even if it does not fit; the caller re-checks.
    """
    index_set = _flatten_ngram(cand_index_set[n - 1])
    n -= 1
    while num_taken + len(index_set) > num_to_predict:
        if n == 0:
            break
        index_set = _flatten_ngram(cand_index_set[n - 1])
        n -= 1
    return index_set


def create_masked_lm_predictions(tokens,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 masked_lm_prob,
                                 cls_id, sep_id, mask_id,
                                 max_predictions_per_seq,
                                 np_rng,
                                 max_ngrams=3,
                                 do_whole_word_mask=True,
                                 favor_longer_ngram=False,
                                 do_permutation=False,
                                 geometric_dist=False,
                                 masking_style="bert",
                                 zh_tokenizer=None):
    """Create the predictions for the masked LM objective.

    Modified from Megatron-LM. Note: `tokens` are vocab ids, not text tokens.

    Args:
        tokens: input sequence of vocab ids.
        vocab_id_list: list of vocab ids to draw random replacements from.
        vocab_id_to_token_dict: mapping from vocab id to token string.
        masked_lm_prob: fraction of the sequence to mask.
        cls_id, sep_id, mask_id: ids of the special tokens.
        max_predictions_per_seq: upper bound on the number of masked positions.
        np_rng: random source; must provide `shuffle`, `choice`, `geometric`,
            `random` and `randint` (e.g. `numpy.random.RandomState`).
        max_ngrams: maximum number of words in one masked span.
        do_whole_word_mask: group "##" pieces with their word (only used
            when `zh_tokenizer` is None).
        favor_longer_ngram: reverse the size distribution so that longer
            n-grams are more likely.
        do_permutation: additionally permute some unmasked tokens.
        geometric_dist: draw masking span sizes from a geometric
            distribution (p=0.2) instead of the 1/n distribution.
        masking_style: "bert" (80% mask / 10% keep / 10% random) or "t5"
            (always mask).
        zh_tokenizer: optional word segmenter for whole word masking; called
            as `zh_tokenizer(text, HMM=True)` (e.g. `jieba.lcut`).

    Returns:
        `(output_tokens, masked_lm_positions, masked_lm_labels,
        token_boundary, masked_spans)`.
        NOTE(review): when `masked_lm_prob == 0` only the first four items
        are returned (no `masked_spans`); kept as-is for compatibility.

    Raises:
        ValueError: if a token is about to be masked and `masking_style`
            is neither "bert" nor "t5".
    """
    if zh_tokenizer is None:
        # Without a segmenter, words are delimited by the "##" marker.
        cand_indexes, token_boundary = _wordpiece_candidates(
            tokens, vocab_id_to_token_dict, cls_id, sep_id,
            do_whole_word_mask)
    else:
        # With a segmenter, segment the text first and match words back.
        cand_indexes, token_boundary = _segmented_candidates(
            tokens, vocab_id_to_token_dict, cls_id, sep_id, zh_tokenizer)

    output_tokens = list(tokens)
    masked_lm_positions = []
    masked_lm_labels = []

    if masked_lm_prob == 0:
        return (output_tokens, masked_lm_positions,
                masked_lm_labels, token_boundary)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
    # By default the probabilities favor shorter n-grams. They are computed
    # unconditionally because the permutation pass needs them even when
    # masking itself uses the geometric distribution.
    pvals = 1. / np.arange(1, max_ngrams + 1)
    pvals /= pvals.sum(keepdims=True)
    if favor_longer_ngram:
        pvals = pvals[::-1]

    # For every word, record the 1..max_ngrams-grams that start at it.
    ngram_indexes = [
        [cand_indexes[idx:idx + n] for n in ngrams]
        for idx in range(len(cand_indexes))
    ]
    np_rng.shuffle(ngram_indexes)

    masked_lms, masked_spans = [], []
    covered_indexes = set()
    for cand_index_set in ngram_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if not cand_index_set:
            continue
        if not geometric_dist:
            n = _choose_ngram_size(np_rng, ngrams, pvals, len(cand_index_set))
        else:
            # Sample "n" from the geometric distribution and clip it to
            # max_ngrams. p=0.2 is the default from the SpanBERT paper
            # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
            n = min(np_rng.geometric(0.2), max_ngrams)

        index_set = _fit_span(cand_index_set, n, len(masked_lms),
                              num_to_predict)
        # If even the shortest candidate exceeds the budget, skip it.
        if len(masked_lms) + len(index_set) > num_to_predict:
            continue
        # Skip spans overlapping anything that is already masked.
        if any(index in covered_indexes for index in index_set):
            continue

        for index in index_set:
            covered_indexes.add(index)
            token_id = tokens[index]
            if masking_style == "bert":
                # 80% of the time, replace with [MASK]
                if np_rng.random() < 0.8:
                    masked_token = mask_id
                # 10% of the time, keep original
                elif np_rng.random() < 0.5:
                    masked_token = tokens[index]
                # 10% of the time, replace with random word
                else:
                    masked_token = vocab_id_list[
                        np_rng.randint(0, len(vocab_id_list))]
            elif masking_style == "t5":
                masked_token = mask_id
            else:
                raise ValueError("invalid value of masking style")

            output_tokens[index] = masked_token
            masked_lms.append(MaskedLmInstance(index=index, label=token_id))

        masked_spans.append(MaskedLmInstance(
            index=index_set,
            label=[tokens[index] for index in index_set]))

    assert len(masked_lms) <= num_to_predict
    # Shuffled even when do_permutation is False, so the state of np_rng
    # after the call does not depend on that flag.
    np_rng.shuffle(ngram_indexes)

    select_indexes = set()
    if do_permutation:
        for cand_index_set in ngram_indexes:
            if len(select_indexes) >= num_to_predict:
                break
            if not cand_index_set:
                continue
            # Draw from the caller's RNG (not the global numpy RNG) so that
            # results are reproducible from np_rng alone.
            n = _choose_ngram_size(np_rng, ngrams, pvals, len(cand_index_set))
            index_set = _fit_span(cand_index_set, n, len(select_indexes),
                                  num_to_predict)
            if len(select_indexes) + len(index_set) > num_to_predict:
                continue
            # Skip spans overlapping masked or already selected positions.
            if any(index in covered_indexes or index in select_indexes
                   for index in index_set):
                continue
            select_indexes.update(index_set)
        assert len(select_indexes) <= num_to_predict

        select_indexes = sorted(select_indexes)
        permute_indexes = list(select_indexes)
        np_rng.shuffle(permute_indexes)
        orig_token = list(output_tokens)

        for src_i, tgt_i in zip(select_indexes, permute_indexes):
            output_tokens[src_i] = orig_token[tgt_i]
            masked_lms.append(
                MaskedLmInstance(index=src_i, label=orig_token[src_i]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)
    # Sort the spans by the index of their first position.
    masked_spans = sorted(masked_spans, key=lambda x: x.index[0])

    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    return (output_tokens, masked_lm_positions, masked_lm_labels,
            token_boundary, masked_spans)