---
language:
- en
widget:
- text: "Punta Cana is a resort town in the municipality of Higuey, in La Altagracia Province, the eastern most province of the Dominican Republic"
tags:
- seq2seq
license: cc-by-nc-sa-4.0
---
|
# REBEL: Relation Extraction By End-to-end Language generation |
|
|
|
This is the model card for the Findings of EMNLP 2021 paper REBEL: Relation Extraction By End-to-end Language generation. We present a new linearization approach and a reframing of Relation Extraction as a seq2seq task. The paper can be found [here](https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf). If you use the code, please reference this work in your paper:
|
|
|
```bibtex
@inproceedings{huguet-cabot-navigli-2021-rebel,
    title = "REBEL: Relation Extraction By End-to-end Language generation",
    author = "Huguet Cabot, Pere-Llu{\'\i}s and
      Navigli, Roberto",
    booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2021",
    month = nov,
    year = "2021",
    address = "Online and in the Barceló Bávaro Convention Centre, Punta Cana, Dominican Republic",
    publisher = "Association for Computational Linguistics",
    url = "https://github.com/Babelscape/rebel/blob/main/docs/EMNLP_2021_REBEL__Camera_Ready_.pdf",
}
```
|
|
|
The original repository for the paper can be found [here](https://github.com/Babelscape/rebel).
|
|
|
## Pipeline usage |
|
|
|
```python3 |
|
from transformers import pipeline

triplet_extractor = pipeline('text2text-generation', model='Babelscape/rebel-large', tokenizer='Babelscape/rebel-large')
# We need to use the tokenizer manually since we need special tokens.
# With return_tensors=True / return_text=False the pipeline returns one dict per
# generated sequence; its "generated_token_ids" entry holds the token ids of that
# sequence (it is not a dict, so it cannot be indexed with "output_ids").
# batch_decode expects a batch of id sequences, so the single sequence is
# wrapped in a list; otherwise each token id would be decoded separately.
generated = triplet_extractor("Punta Cana is a resort town in the municipality of Higuey, in La Altagracia Province, the eastern most province of the Dominican Republic", return_tensors=True, return_text=False)
extracted_text = triplet_extractor.tokenizer.batch_decode([generated[0]["generated_token_ids"]])
print(extracted_text[0])
|
# Function to parse the generated text and extract the triplets
def extract_triplets(text):
    """Parse REBEL's linearized output into a list of triplet dicts.

    REBEL emits triplets as
    ``<triplet> head <subj> tail <obj> relation [<subj> tail <obj> relation ...]``,
    where each ``<triplet>`` starts a new head that is shared by every
    following (tail, relation) pair until the next ``<triplet>``.
    The ``<s>``, ``</s>`` and ``<pad>`` tokens are removed first; the rest is
    split on whitespace, so the ``<triplet>``/``<subj>``/``<obj>`` markers must
    be whitespace-separated from neighbouring words.

    Args:
        text: A sequence decoded with special tokens kept
            (``skip_special_tokens=False``).

    Returns:
        A list of ``{'head': ..., 'type': ..., 'tail': ...}`` dicts, in the
        order they appear in ``text``.
    """
    triplets = []
    subject, relation, object_ = '', '', ''

    def emit():
        # Record the triplet accumulated so far.
        triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})

    # Parser state: 'x' = before the first marker (tokens are ignored),
    # 't' = after <triplet> (reading the head), 's' = after <subj> (reading the
    # tail), 'o' = after <obj> (reading the relation).
    current = 'x'
    cleaned = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in cleaned.split():
        if token == "<triplet>":
            current = 't'
            # A collected relation means the previous triplet is complete.
            if relation:
                emit()
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            # Another tail for the same head: close the previous triplet first.
            if relation:
                emit()
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token
    # The last triplet has no closing marker, so flush it if it is complete.
    if subject and relation and object_:
        emit()
    return triplets
|
# Parse the first decoded sequence into {'head', 'type', 'tail'} dicts and show them.
first_prediction = extracted_text[0]
extracted_triplets = extract_triplets(first_prediction)
print(extracted_triplets)
|
``` |
|
|
|
## Model and Tokenizer using transformers |
|
|
|
```python3 |
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
|
def extract_triplets(text):
    """Parse REBEL's linearized output into a list of triplet dicts.

    REBEL emits triplets as
    ``<triplet> head <subj> tail <obj> relation [<subj> tail <obj> relation ...]``,
    where each ``<triplet>`` starts a new head that is shared by every
    following (tail, relation) pair until the next ``<triplet>``.
    The ``<s>``, ``</s>`` and ``<pad>`` tokens are removed first; the rest is
    split on whitespace, so the ``<triplet>``/``<subj>``/``<obj>`` markers must
    be whitespace-separated from neighbouring words.

    Args:
        text: A sequence decoded with special tokens kept
            (``skip_special_tokens=False``).

    Returns:
        A list of ``{'head': ..., 'type': ..., 'tail': ...}`` dicts, in the
        order they appear in ``text``.
    """
    triplets = []
    subject, relation, object_ = '', '', ''

    def emit():
        # Record the triplet accumulated so far.
        triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})

    # Parser state: 'x' = before the first marker (tokens are ignored),
    # 't' = after <triplet> (reading the head), 's' = after <subj> (reading the
    # tail), 'o' = after <obj> (reading the relation).
    current = 'x'
    cleaned = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in cleaned.split():
        if token == "<triplet>":
            current = 't'
            # A collected relation means the previous triplet is complete.
            if relation:
                emit()
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            # Another tail for the same head: close the previous triplet first.
            if relation:
                emit()
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token
    # The last triplet has no closing marker, so flush it if it is complete.
    if subject and relation and object_:
        emit()
    return triplets
|
|
|
# Load the REBEL checkpoint and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

# Decoding settings: beam search with 3 beams, returning all 3 sequences.
gen_kwargs = dict(
    max_length=256,
    length_penalty=0,
    num_beams=3,
    num_return_sequences=3,
)

# Text to extract triplets from
text = 'Punta Cana is a resort town in the municipality of Higüey, in La Altagracia Province, the easternmost province of the Dominican Republic.'

# Tokenize into PyTorch tensors, truncating the input to at most 256 tokens.
model_inputs = tokenizer(text, max_length=256, padding=True, truncation=True, return_tensors='pt')

# Place the inputs on the model's device, then generate.
input_ids = model_inputs["input_ids"].to(model.device)
attention_mask = model_inputs["attention_mask"].to(model.device)
generated_tokens = model.generate(input_ids, attention_mask=attention_mask, **gen_kwargs)

# Keep special tokens: extract_triplets needs the <triplet>/<subj>/<obj> markers.
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

# Parse and print the triplets of every returned sequence.
for idx, sentence in enumerate(decoded_preds):
    print(f'Prediction triplets sentence {idx}')
    print(extract_triplets(sentence))
|
``` |