Spaces:
Build error
Build error
from copy import deepcopy | |
import torch | |
import dgl | |
import stanza | |
import networkx as nx | |
class Sentence2GraphParser:
    """Turn sentences into word-level syntactic graphs using a stanza pipeline.

    ``parse`` returns ``(dgl_graph, edge_types)`` where ``edge_types`` is a
    ``torch.LongTensor`` holding one type id per graph edge, in edge order.
    With the default index layout, node 0 is ``<BOS>``, nodes ``1..num_word``
    are the words (English) or characters (Chinese) and node ``num_word + 1``
    is ``<EOS>``.

    Edge type ids:

    * 0 -- dependency edge, dependent -> head
    * 1 -- reversed dependency edge, head -> dependent (``enable_backward_edge``)
    * 2 -- self-loop on every word (``enable_recur_edge``)
    * 3 -- edge between the root words of different sentences
      (``enable_inter_sentence_edge``)
    * 4 -- sequential "forward" edge
    * 5 -- sequential "backward" edge
    """

    # Full-width (Chinese) comma. ASCII commas are converted to it before
    # Chinese parsing; the original authors found this necessary for stanza's
    # dependency parsing.
    _ZH_COMMA = '\uff0c'
    _COMMAS = (',', _ZH_COMMA)

    def __init__(self, language='zh', use_gpu=False, download=False):
        """Create the stanza pipeline for ``language``.

        Args:
            language: stanza language code; ``parse`` supports 'zh' and 'en'.
            use_gpu: forwarded to ``stanza.Pipeline``.
            download: when False, ``download_method=None`` is passed so stanza
                does not try to download resources (presumably they must
                already be installed locally).
        """
        self.language = language
        pipeline_kwargs = {} if download else {'download_method': None}
        self.stanza_parser = stanza.Pipeline(lang=language, use_gpu=use_gpu, **pipeline_kwargs)

    def parse(self, clean_sentence=None, words=None, ph_words=None):
        """Build the graph for one input, dispatching on ``self.language``.

        For 'zh', ``words`` and ``ph_words`` are required (see ``_parse_zh``).
        For 'en', ``clean_sentence`` is required (see ``_parse_en``).

        Returns:
            ``(dgl_graph, edge_types)`` as described on the class.

        Raises:
            ValueError: if the inputs required for the language are missing.
            NotImplementedError: for any language other than 'zh' / 'en'.
        """
        if self.language == 'zh':
            if words is None or ph_words is None:
                raise ValueError("Chinese parsing requires both `words` and `ph_words`.")
            return self._parse_zh(words, ph_words)
        if self.language == 'en':
            if clean_sentence is None:
                raise ValueError("English parsing requires `clean_sentence`.")
            return self._parse_en(clean_sentence)
        raise NotImplementedError(f"Unsupported language: {self.language!r}")

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _prepare_zh_input(cls, words, ph_words):
        """Drop <BOS>/<EOS>, normalise commas and build the stanza input text.

        Entries whose phoneme is an ASCII ',' are replaced by the full-width
        comma in both lists. The text concatenates the characters. A blank is
        added after every character whose phoneme ends with '#', to separate
        words, and on both sides of each comma.

        Returns:
            ``(words, ph_words, clean_text)``. The lists are new; the arguments
            are not modified.

        Raises:
            ValueError: if ``words`` and ``ph_words`` differ in length.
        """
        if len(words) != len(ph_words):
            raise ValueError(
                f"`words` and `ph_words` must have the same length, got {len(words)} and {len(ph_words)}.")
        words, ph_words = list(words[1:-1]), list(ph_words[1:-1])  # delete <BOS> and <EOS>
        for i, p_w in enumerate(ph_words):
            if p_w == ',':
                words[i] = ph_words[i] = cls._ZH_COMMA
        pieces = []
        for word, p_w in zip(words, ph_words):
            if p_w in cls._COMMAS:
                pieces += [' ', word, ' ']
            elif p_w.endswith('#'):
                pieces += [word, ' ']
            else:
                pieces.append(word)
        return words, ph_words, ''.join(pieces).strip()

    @staticmethod
    def _align_vocab_to_words(sentences, words):
        """Map each stanza word to the characters of ``words`` that it covers.

        Keys are the stanza word ids made global (1-based, counted across all
        sentences). Values are lists of 1-based indices into ``words``. Blanks
        in the stanza text are ignored when matching.

        Raises:
            ValueError: if the stanza text does not match ``words`` character by
                character.
        """
        vocab_to_word = {}
        next_word = 1
        offset = 0
        for sentence in sentences:
            for node in sentence.words:
                covered = []
                for char in node.text.replace(" ", ""):
                    if next_word > len(words) or char != words[next_word - 1]:
                        raise ValueError("Word Mismatch!")
                    covered.append(next_word)
                    next_word += 1
                vocab_to_word[node.id + offset] = covered
            offset += len(sentence.words)
        return vocab_to_word

    @staticmethod
    def _dependency_edges(sentences):
        """Collect dependent -> head edges over all sentences using global 1-based ids.

        Returns:
            ``(source, dest, sentence_heads, num_nodes)``. Root words
            (``head == 0``) are listed in ``sentence_heads`` and produce no edge.
        """
        source, dest, heads = [], [], []
        offset = 0
        for sentence in sentences:
            for word in sentence.words:
                if word.head == 0:
                    heads.append(word.id + offset)
                else:
                    source.append(word.id + offset)
                    dest.append(word.head + offset)
            offset += len(sentence.words)
        return source, dest, heads, offset

    @staticmethod
    def _full_graph_edges(nodes):
        """Return every ordered pair of distinct positions in ``nodes``."""
        return [(node_i, node_j)
                for i, node_i in enumerate(nodes)
                for j, node_j in enumerate(nodes)
                if i != j]

    @staticmethod
    def _sequential_edges(num_nodes):
        """Chain nodes ``1..num_nodes`` with forward (type 4) and backward (type 5) edges.

        Returns:
            ``(source, dest, types)``, where each list has length
            ``2 * (num_nodes - 1)`` (empty when ``num_nodes < 2``).
        """
        fwd_src, fwd_dst = list(range(1, num_nodes)), list(range(2, num_nodes + 1))
        bwd_src, bwd_dst = list(range(num_nodes, 1, -1)), list(range(num_nodes - 1, 0, -1))
        types = [4] * len(fwd_src) + [5] * len(bwd_src)
        return fwd_src + bwd_src, fwd_dst + bwd_dst, types

    @staticmethod
    def _bos_eos_edges(num_word):
        """Edges linking <BOS> (node 0) and <EOS> (node ``num_word + 1``) to the sequence.

        NOTE(review): under the 4=forward / 5=backward convention the two <EOS>
        edges look swapped: ``num_word + 1 -> num_word`` is typed 4 and
        ``num_word -> num_word + 1`` is typed 5. They are kept unchanged
        because trained models may depend on these exact ids; confirm before
        changing.
        """
        return [0, num_word + 1, 1, num_word], [1, num_word, 0, num_word + 1], [4, 4, 5, 5]

    # ------------------------------------------------------------------ parsers

    def _parse_zh(self, words, ph_words, enable_backward_edge=True, enable_recur_edge=True,
                  enable_inter_sentence_edge=True, sequential_edge=False):
        """Build a character-level graph for a Chinese sentence.

        words: <List of str>, one item per Chinese character, with '<BOS>' first
            and '<EOS>' last.
        ph_words: <List of str>, aligned with ``words``; each character is
            represented by its phoneme. A trailing '#' marks the end of a word.
        Example:
            text1 = '宝马配挂跛骡鞍,貂蝉怨枕董翁榻.'
            words = ['<BOS>', '宝', '马', '配', '挂', '跛', '骡', '鞍', ','
                    , '貂', '蝉', '怨', '枕', '董', '翁', '榻', '<EOS>']
            ph_words = ['<BOS>', 'b_ao3_|', 'm_a3_#', 'p_ei4_|', 'g_ua4_#',
                        'b_o3_#', 'l_uo2_|', 'an1', ',', 'd_iao1_|',
                        'ch_an2_#', 'van4_#', 'zh_en3_#', 'd_ong3_|', 'ueng1_#', 't_a4', '<EOS>']

        Stanza parses at word level ("vocab" nodes). Each vocab-level edge is
        then attached to the first character of each endpoint word, and the
        characters inside a word are chained with type 4/5 edges.

        Returns:
            ``(dgl_graph, edge_types)``.

        Raises:
            ValueError: on length mismatch or when stanza's text does not match
                ``words``.
        """
        words, ph_words, clean_text = self._prepare_zh_input(words, ph_words)
        sentences = self.stanza_parser(clean_text).sentences
        vocab_to_word = self._align_vocab_to_words(sentences, words)
        if len(sentences) > 5:
            print("Detect more than 5 input sentence! pls check whether the sentence is too long!")

        # Vocab-level edges (ids are global 1-based stanza word ids).
        v_src, v_dst, heads, num_vocab = self._dependency_edges(sentences)
        v_types = [0] * len(v_src)
        if enable_backward_edge:
            back_src, back_dst = list(v_dst), list(v_src)
            v_src += back_src
            v_dst += back_dst
            v_types += [1] * len(back_src)
        if enable_inter_sentence_edge and len(heads) > 1:
            pairs = self._full_graph_edges(heads)
            v_src += [s for s, _ in pairs]
            v_dst += [d for _, d in pairs]
            v_types += [3] * len(pairs)
        if sequential_edge:
            seq_src, seq_dst, seq_types = self._sequential_edges(num_vocab)
            v_src += seq_src
            v_dst += seq_dst
            v_types += seq_types

        # Word(character)-level graph built from the vocab-level edges.
        num_word = len(words)
        source_id, dest_id, edge_types = [], [], []
        for v_s, v_d, v_t in zip(v_src, v_dst, v_types):
            # connect the first character of each endpoint word
            source_id.append(min(vocab_to_word[v_s]))
            dest_id.append(min(vocab_to_word[v_d]))
            edge_types.append(v_t)
        # sequential connection between the characters of one word
        for indices in vocab_to_word.values():
            for pos, word_idx in enumerate(indices):
                if pos + 1 < len(indices):
                    source_id.append(word_idx)
                    dest_id.append(word_idx + 1)
                    edge_types.append(4)
                if pos >= 1:
                    source_id.append(word_idx)
                    dest_id.append(word_idx - 1)
                    edge_types.append(5)
        if enable_recur_edge:
            loops = list(range(1, num_word + 1))
            source_id += loops
            dest_id += loops
            edge_types += [2] * len(loops)
        bos_src, bos_dst, bos_types = self._bos_eos_edges(num_word)
        source_id += bos_src
        dest_id += bos_dst
        edge_types += bos_types

        dgl_graph = dgl.graph((torch.LongTensor(source_id), torch.LongTensor(dest_id)))
        assert dgl_graph.num_edges() == len(edge_types)
        return dgl_graph, torch.LongTensor(edge_types)

    def _parse_en(self, clean_sentence, enable_backward_edge=True, enable_recur_edge=True,
                  enable_inter_sentence_edge=True, sequential_edge=False, consider_bos_for_index=True):
        """Build a word-level graph for an English sentence.

        clean_sentence: <str>, each word or punctuation should be separated by one blank.
            One trailing " <punct>" (. , ; : ? !) and one leading ". " are
            removed before parsing.

        When ``consider_bos_for_index`` is False, every node id is shifted down
        by one. NOTE(review): this makes the <BOS> endpoints -1. Confirm this
        path is valid before relying on it.

        Returns:
            ``(dgl_graph, edge_types)``.
        """
        clean_sentence = clean_sentence.strip()
        if clean_sentence.endswith((" .", " ,", " ;", " :", " ?", " !")):
            clean_sentence = clean_sentence[:-2]
        if clean_sentence.startswith(". "):
            clean_sentence = clean_sentence[2:]
        sentences = self.stanza_parser(clean_sentence).sentences
        if len(sentences) > 5:
            print("Detect more than 5 input sentence! pls check whether the sentence is too long!")
            print(clean_sentence)

        source_id, dest_id, heads, num_word = self._dependency_edges(sentences)
        edge_types = [0] * len(source_id)
        if enable_backward_edge:
            back_src, back_dst = list(dest_id), list(source_id)
            source_id += back_src
            dest_id += back_dst
            edge_types += [1] * len(back_src)
        if enable_recur_edge:
            loops = list(range(1, num_word + 1))
            source_id += loops
            dest_id += loops
            edge_types += [2] * len(loops)
        if enable_inter_sentence_edge and len(heads) > 1:
            pairs = self._full_graph_edges(heads)
            source_id += [s for s, _ in pairs]
            dest_id += [d for _, d in pairs]
            edge_types += [3] * len(pairs)
        bos_src, bos_dst, bos_types = self._bos_eos_edges(num_word)
        source_id += bos_src
        dest_id += bos_dst
        edge_types += bos_types
        if sequential_edge:
            seq_src, seq_dst, seq_types = self._sequential_edges(num_word)
            source_id += seq_src
            dest_id += seq_dst
            edge_types += seq_types

        src_tensor, dst_tensor = torch.LongTensor(source_id), torch.LongTensor(dest_id)
        if not consider_bos_for_index:
            src_tensor, dst_tensor = src_tensor - 1, dst_tensor - 1
        dgl_graph = dgl.graph((src_tensor, dst_tensor))
        assert dgl_graph.num_edges() == len(edge_types)
        return dgl_graph, torch.LongTensor(edge_types)
def plot_dgl_sentence_graph(dgl_graph, labels):
    """Draw a sentence graph with a random node layout and show it.

    Args:
        dgl_graph: graph to draw; it is converted via ``to_networkx()``.
        labels: mapping from node index to the text drawn on that node, e.g.
            ``{idx: word for idx, word in enumerate(sentence.split(" "))}``.
    """
    import matplotlib.pyplot as plt

    graph_nx = dgl_graph.to_networkx()
    layout = nx.random_layout(graph_nx)
    # Draw nodes/edges without the default integer labels, then overlay the text labels.
    nx.draw(graph_nx, layout, with_labels=False)
    nx.draw_networkx_labels(graph_nx, layout, labels=labels)
    plt.show()
if __name__ == '__main__':
    # Smoke test: Chinese graph builder (parses and plots the graph).
    zh_parser = Sentence2GraphParser("zh")
    text1 = '宝马配挂跛骡鞍,貂蝉怨枕董翁榻.'
    words = ['<BOS>', '宝', '马', '配', '挂', '跛', '骡', '鞍', ',', '貂', '蝉', '怨', '枕', '董', '翁', '榻', '<EOS>']
    ph_words = ['<BOS>', 'b_ao3_|', 'm_a3_#', 'p_ei4_|', 'g_ua4_#', 'b_o3_#', 'l_uo2_|', 'an1', ',', 'd_iao1_|',
                'ch_an2_#', 'van4_#', 'zh_en3_#', 'd_ong3_|', 'ueng1_#', 't_a4', '<EOS>']
    graph1, etypes1 = zh_parser.parse(clean_sentence=text1, words=words, ph_words=ph_words)
    plot_dgl_sentence_graph(graph1, dict(enumerate(ph_words)))

    # Smoke test: English graph builder (parses and plots the graph).
    en_parser = Sentence2GraphParser("en")
    text2 = "I love you . You love me . Mixue ice-scream and tea ."
    graph2, etypes2 = en_parser.parse(text2)
    en_labels = dict(enumerate(("<BOS> " + text2 + " <EOS>").split(" ")))
    plot_dgl_sentence_graph(graph2, en_labels)