bmorphism mamta commited on
Commit
955d3fb
0 Parent(s):

Duplicate from Babelscape/rebel-demo

Browse files

Co-authored-by: Mamta Narang <[email protected]>

Files changed (4) hide show
  1. .gitattributes +27 -0
  2. README.md +38 -0
  3. app.py +104 -0
  4. requirements.txt +3 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Rebel Demo
3
+ emoji: 🌍
4
+ colorFrom: purple
5
+ colorTo: pink
6
+ sdk: streamlit
7
+ app_file: app.py
8
+ pinned: false
9
+ duplicated_from: Babelscape/rebel-demo
10
+ ---
11
+
12
+ # Configuration
13
+
14
+ `title`: _string_
15
+ Display title for the Space
16
+
17
+ `emoji`: _string_
18
+ Space emoji (emoji-only character allowed)
19
+
20
+ `colorFrom`: _string_
21
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
22
+
23
+ `colorTo`: _string_
24
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
25
+
26
+ `sdk`: _string_
27
+ Can be either `gradio` or `streamlit`
28
+
29
+ `sdk_version` : _string_
30
+ Only applicable for `streamlit` SDK.
31
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
32
+
33
+ `app_file`: _string_
34
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
35
+ Path is relative to the root of the repository.
36
+
37
+ `pinned`: _boolean_
38
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from datasets import load_dataset
3
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
4
+ from time import time
5
+ import torch
6
+
7
+ @st.cache(
8
+ allow_output_mutation=True,
9
+ hash_funcs={
10
+ AutoTokenizer: lambda x: None,
11
+ AutoModelForSeq2SeqLM: lambda x: None,
12
+ },
13
+ suppress_st_warning=True
14
+ )
15
+ def load_models():
16
+ st_time = time()
17
+ tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
18
+ print("+++++ loading Model", time() - st_time)
19
+ model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
20
+ if torch.cuda.is_available():
21
+ _ = model.to("cuda:0") # comment if no GPU available
22
+ _ = model.eval()
23
+ print("+++++ loaded model", time() - st_time)
24
+ dataset = load_dataset('Babelscape/rebel-dataset', split="validation", streaming=True)
25
+ dataset = [example for example in dataset.take(1001)]
26
+ return (tokenizer, model, dataset)
27
+
28
+ def extract_triplets(text):
29
+ triplets = []
30
+ relation, subject, relation, object_ = '', '', '', ''
31
+ text = text.strip()
32
+ current = 'x'
33
+ for token in text.split():
34
+ if token == "<triplet>":
35
+ current = 't'
36
+ if relation != '':
37
+ triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
38
+ relation = ''
39
+ subject = ''
40
+ elif token == "<subj>":
41
+ current = 's'
42
+ if relation != '':
43
+ triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
44
+ object_ = ''
45
+ elif token == "<obj>":
46
+ current = 'o'
47
+ relation = ''
48
+ else:
49
+ if current == 't':
50
+ subject += ' ' + token
51
+ elif current == 's':
52
+ object_ += ' ' + token
53
+ elif current == 'o':
54
+ relation += ' ' + token
55
+ if subject != '' and relation != '' and object_ != '':
56
+ triplets.append({'head': subject.strip(), 'type': relation.strip(),'tail': object_.strip()})
57
+ return triplets
58
+
59
+ st.markdown("""This is a demo for the Findings of EMNLP 2021 paper [REBEL: Relation Extraction By End-to-end Language generation](https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf). The pre-trained model is able to extract triplets for up to 200 relation types from Wikidata or be used in downstream Relation Extraction task by fine-tuning. Find the model card [here](https://huggingface.co/Babelscape/rebel-large). Read more about it in the [paper](https://aclanthology.org/2021.findings-emnlp.204) and in the original [repository](https://github.com/Babelscape/rebel).""")
60
+
61
+ tokenizer, model, dataset = load_models()
62
+
63
+ agree = st.checkbox('Free input', False)
64
+ if agree:
65
+ text = st.text_input('Input text', 'Punta Cana is a resort town in the municipality of Higüey, in La Altagracia Province, the easternmost province of the Dominican Republic.')
66
+ print(text)
67
+ else:
68
+ dataset_example = st.slider('dataset id', 0, 1000, 0)
69
+ text = dataset[dataset_example]['context']
70
+ length_penalty = st.slider('length_penalty', 0, 10, 0)
71
+ num_beams = st.slider('num_beams', 1, 20, 3)
72
+ num_return_sequences = st.slider('num_return_sequences', 1, num_beams, 2)
73
+
74
+ gen_kwargs = {
75
+ "max_length": 256,
76
+ "length_penalty": length_penalty,
77
+ "num_beams": num_beams,
78
+ "num_return_sequences": num_return_sequences,
79
+ }
80
+
81
+ model_inputs = tokenizer(text, max_length=256, padding=True, truncation=True, return_tensors = 'pt')
82
+ generated_tokens = model.generate(
83
+ model_inputs["input_ids"].to(model.device),
84
+ attention_mask=model_inputs["attention_mask"].to(model.device),
85
+ **gen_kwargs,
86
+ )
87
+
88
+ decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)
89
+ st.title('Input text')
90
+
91
+ st.write(text)
92
+
93
+ if not agree:
94
+ st.title('Silver output')
95
+ st.write(dataset[dataset_example]['triplets'])
96
+ st.write(extract_triplets(dataset[dataset_example]['triplets']))
97
+
98
+ st.title('Prediction text')
99
+ decoded_preds = [text.replace('<s>', '').replace('</s>', '').replace('<pad>', '') for text in decoded_preds]
100
+ st.write(decoded_preds)
101
+
102
+ for idx, sentence in enumerate(decoded_preds):
103
+ st.title(f'Prediction triplets sentence {idx}')
104
+ st.write(extract_triplets(sentence))
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ streamlit
3
+ transformers