keyphrase-extraction / pipelines /keyphrase_generation_pipeline.py
Thomas De Decker
Add max input length
f2f4fc6
raw
history blame
1.03 kB
import string
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
Text2TextGenerationPipeline,
)
class KeyphraseGenerationPipeline(Text2TextGenerationPipeline):
    """Text2Text pipeline that turns generated text into a list of keyphrases.

    The seq2seq model's generated text is split on ``keyphrase_sep_token`` and
    each piece is cleaned (punctuation removed, whitespace stripped) in
    ``postprocess``.
    """

    # Translation table that deletes every character in string.punctuation.
    # Built once at class creation instead of twice per keyphrase.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    def __init__(self, model, keyphrase_sep_token=";", *args, **kwargs):
        """Load the model and tokenizer and initialise the parent pipeline.

        Args:
            model: identifier or path handed to both
                ``AutoModelForSeq2SeqLM.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
            keyphrase_sep_token: string the generated text is split on.
            *args: extra positional arguments forwarded to the parent
                pipeline after the model and tokenizer.
            **kwargs: extra keyword arguments forwarded to the parent pipeline.
        """
        # Model and tokenizer are passed positionally so that *args follow
        # them. Passing them as keywords together with *args would bind the
        # first extra positional argument to ``model`` as well, raising
        # TypeError ("multiple values for argument").
        super().__init__(
            AutoModelForSeq2SeqLM.from_pretrained(model),
            AutoTokenizer.from_pretrained(model),
            *args,
            **kwargs,
        )
        self.keyphrase_sep_token = keyphrase_sep_token

    def postprocess(self, model_outputs, **postprocess_params):
        """Split the first generated sequence into cleaned keyphrases.

        Args:
            model_outputs: forward-pass output, handed unchanged to the parent
                ``postprocess``.
            **postprocess_params: extra keyword options forwarded unchanged to
                the parent ``postprocess`` (presumably the post-processing
                options the parent pipeline routes here, e.g.
                ``clean_up_tokenization_spaces`` -- confirm against the
                installed transformers version).

        Returns:
            list[str]: pieces of the first result's ``"generated_text"`` split
            on ``keyphrase_sep_token``, each with all punctuation deleted and
            surrounding whitespace stripped; pieces that end up empty are
            dropped. Returns ``[]`` if the parent produced no results.
        """
        results = super().postprocess(
            model_outputs=model_outputs, **postprocess_params
        )
        if not results:
            return []

        # Only the first generated result is used; any further results are
        # ignored (same as the original ``[...][0]``), so only it is processed.
        generated_text = results[0].get("generated_text")

        # Delete punctuation *before* stripping: "( foo )" -> "foo" rather
        # than " foo ", and whitespace/punctuation-only pieces collapse to ""
        # so the filter below drops them.
        # NOTE(review): this also deletes in-word punctuation
        # ("state-of-the-art" -> "stateoftheart"); confirm that is intended.
        cleaned = (
            piece.translate(self._PUNCTUATION_TABLE).strip()
            for piece in generated_text.split(self.keyphrase_sep_token)
        )
        return [keyphrase for keyphrase in cleaned if keyphrase]