PereLluis13 committed on
Commit
e6f2745
1 Parent(s): 1678277

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from datasets import load_dataset
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+ from time import time
5
+ import torch
6
+
7
@st.cache(
    allow_output_mutation=True,
    hash_funcs={
        AutoTokenizer: lambda x: None,
        AutoModelForSeq2SeqLM: lambda x: None,
    },
    suppress_st_warning=True
)
def load_models(lan):
    """Load the mREBEL tokenizer and model plus a validation sample for ``lan``.

    The result is memoised by ``st.cache`` so Streamlit reruns with the same
    language do not reload everything.

    Args:
        lan: Language key. It is looked up in the module-level ``_Tokens``
            mapping to obtain the tokenizer's ``src_lang`` and is also passed
            as the configuration name to ``load_dataset``.

    Returns:
        A ``(tokenizer, model, dataset)`` tuple where ``dataset`` is a list of
        up to 1001 examples taken from the streamed validation split.

    Raises:
        KeyError: If ``lan`` has no entry in ``_Tokens``.
    """
    st_time = time()
    # The target language must be passed as a keyword argument; the previous
    # `"tgt_lang": "tp_XX"` form is dict syntax and is a SyntaxError in a call.
    tokenizer = AutoTokenizer.from_pretrained(
        "Babelscape/mrebel-large",
        src_lang=_Tokens[lan],
        tgt_lang="tp_XX",
    )
    print("+++++ loading Model", time() - st_time)
    model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/mrebel-large")
    if torch.cuda.is_available():
        # Only move to the GPU when one is actually present.
        _ = model.to("cuda:0")
    _ = model.eval()
    print("+++++ loaded model", time() - st_time)
    dataset = load_dataset('Babelscape/SREDFM', lan, split="validation", streaming=True)
    # Materialise the first examples so they can be indexed by the UI slider.
    dataset = list(dataset.take(1001))
    return (tokenizer, model, dataset)
27
+
28
def _typed_triplet(subject, subject_type, relation, object_, object_type):
    """Build one output record from the accumulated (space-prefixed) parts."""
    return {
        'head': subject.strip(),
        'head_type': subject_type,
        'type': relation.strip(),
        'tail': object_.strip(),
        'tail_type': object_type,
    }


def extract_triplets_typed(text):
    """Parse a linearised, typed triplet string into a list of dicts.

    The text is split on whitespace after removing the markers ``<s>``,
    ``<pad>``, ``</s>``, ``tp_XX`` and ``__en__``. Tokens are then consumed by
    a small state machine:

    * ``<triplet>`` / ``<relation>`` start a new head entity;
    * any other ``<...>`` token is a type tag: after a head (or after a
      relation) it gives the head type and starts a tail entity, after a
      tail it gives the tail type and starts the relation label;
    * every other token is appended to whichever part is being collected.

    Args:
        text: Decoded model output or a reference string in the same format.

    Returns:
        A list of dicts with keys ``head``, ``head_type``, ``type``, ``tail``
        and ``tail_type``, in order of appearance. Empty when nothing could
        be parsed.
    """
    triplets = []
    # Parser state: 'x' = nothing seen yet, 't' = collecting head,
    # 's' = collecting tail, 'o' = collecting relation label.
    current = 'x'
    subject, relation, object_, object_type, subject_type = '', '', '', '', ''

    cleaned = text.strip()
    for marker in ("<s>", "<pad>", "</s>", "tp_XX", "__en__"):
        cleaned = cleaned.replace(marker, "")

    for token in cleaned.split():
        if token == "<triplet>" or token == "<relation>":
            current = 't'
            if relation != '':
                # A complete triplet was pending; flush it before starting over.
                triplets.append(
                    _typed_triplet(subject, subject_type, relation, object_, object_type)
                )
                relation = ''
            subject = ''
        elif token.startswith("<") and token.endswith(">"):
            if current == 't' or current == 'o':
                current = 's'
                if relation != '':
                    # Same head, next tail: flush the triplet just completed.
                    triplets.append(
                        _typed_triplet(subject, subject_type, relation, object_, object_type)
                    )
                object_ = ''
                subject_type = token[1:-1]
            else:
                current = 'o'
                object_type = token[1:-1]
                relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token

    # Flush the trailing triplet only when every component has been filled in.
    if subject != '' and relation != '' and object_ != '' and object_type != '' and subject_type != '':
        triplets.append(
            _typed_triplet(subject, subject_type, relation, object_, object_type)
        )
    return triplets
63
+
64
# ---------------------------------------------------------------------------
# Streamlit page: pick a language and an input sentence, run generation with
# the cached model, and show the raw and parsed triplets.
# ---------------------------------------------------------------------------
st.markdown("""This is a demo for the Findings of EMNLP 2021 paper [REBEL: Relation Extraction By End-to-end Language generation](https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf). The pre-trained model is able to extract triplets for up to 200 relation types from Wikidata or be used in downstream Relation Extraction task by fine-tuning. Find the model card [here](https://huggingface.co/Babelscape/rebel-large). Read more about it in the [paper](https://aclanthology.org/2021.findings-emnlp.204) and in the original [repository](https://github.com/Babelscape/rebel).""")

lan = st.selectbox(
    'Select a Language',
    ('ar', 'ca', 'de', 'el', 'en', 'es', 'fr', 'hi', 'it', 'ja', 'ko', 'nl', 'pl', 'pt', 'ru', 'sv', 'vi', 'zh'))

# Maps each selectable language to the tokenizer's source-language code.
# 'es' and 'nl' are offered in the selectbox above, so they need entries here
# too; without them load_models() raised KeyError for those two languages.
# The added codes follow the same xx_YY convention as the others
# (assumed to be the mBART-50 codes; confirm against the tokenizer vocabulary).
_Tokens = {
    'en': 'en_XX', 'de': 'de_DE', 'ca': 'ca_XX', 'ar': 'ar_AR',
    'el': 'el_EL', 'es': 'es_XX', 'it': 'it_IT', 'ja': 'ja_XX',
    'ko': 'ko_KR', 'hi': 'hi_IN', 'nl': 'nl_XX', 'pt': 'pt_XX',
    'ru': 'ru_RU', 'pl': 'pl_PL', 'zh': 'zh_CN', 'fr': 'fr_XX',
    'vi': 'vi_VN', 'sv': 'sv_SE',
}

tokenizer, model, dataset = load_models(lan)

# Either free text typed by the user or an example picked from the dataset.
agree = st.checkbox('Free input', False)
if agree:
    text = st.text_input('Input text', 'Els Red Hot Chili Peppers es van formar a Los Angeles per Kiedis, Flea, el guitarrista Hillel Slovak i el bateria Jack Irons.')
    print(text)
else:
    dataset_example = st.slider('dataset id', 0, 1000, 0)
    text = dataset[dataset_example]['context']

length_penalty = st.slider('length_penalty', 0, 10, 0)
num_beams = st.slider('num_beams', 1, 20, 3)
if num_beams > 1:
    num_return_sequences = st.slider('num_return_sequences', 1, num_beams, 2)
else:
    # With a single beam only one sequence can be returned; a slider with
    # range 1..1 and default 2 would be invalid, so pin the value instead.
    num_return_sequences = 1

gen_kwargs = {
    "max_length": 256,
    "length_penalty": length_penalty,
    "num_beams": num_beams,
    "num_return_sequences": num_return_sequences,
    "forced_bos_token_id": None,
}

model_inputs = tokenizer(text, max_length=256, padding=True, truncation=True, return_tensors='pt')
generated_tokens = model.generate(
    model_inputs["input_ids"].to(model.device),
    attention_mask=model_inputs["attention_mask"].to(model.device),
    # Decoding starts from the "tp_XX" token.
    decoder_start_token_id=tokenizer.convert_tokens_to_ids("tp_XX"),
    **gen_kwargs,
)

# Keep special tokens: the triplet/type markers are needed by the parser.
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
st.title('Input text')

st.write(text)

if not agree:
    # Reference annotation that ships with the selected dataset example.
    st.title('Silver output')
    st.write(dataset[dataset_example]['triplets'])
    st.write(extract_triplets_typed(dataset[dataset_example]['triplets']))

st.title('Prediction text')
decoded_preds = [
    pred.replace('<s>', '').replace('</s>', '').replace('<pad>', '')
    for pred in decoded_preds
]
st.write(decoded_preds)

for idx, sentence in enumerate(decoded_preds):
    st.title(f'Prediction triplets sentence {idx}')
    st.write(extract_triplets_typed(sentence))