import string

from transformers import (
    Text2TextGenerationPipeline,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)


class KeyphraseGenerationPipeline(Text2TextGenerationPipeline):
    """Text-to-text pipeline that returns the generated keyphrases as a list.

    The model and tokenizer are both loaded with ``from_pretrained`` from the
    same ``model`` identifier. The generated text is split on
    ``keyphrase_sep_token`` and each piece is cleaned of ASCII punctuation
    and surrounding whitespace.
    """

    # Translation table deleting every character in ``string.punctuation``.
    # Built once here instead of once (twice, originally) per keyphrase.
    # Note that ``string.punctuation`` includes "-", so hyphenated keyphrases
    # are joined ("state-of-the-art" -> "stateoftheart").
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)

    def __init__(self, model, keyphrase_sep_token=";", *args, **kwargs):
        """Load the seq2seq model/tokenizer and remember the separator.

        Args:
            model: Identifier or path passed to both
                ``AutoModelForSeq2SeqLM.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
            keyphrase_sep_token: String on which the generated text is split
                into individual keyphrases.
            *args: Forwarded to the parent pipeline constructor.
            **kwargs: Forwarded to the parent pipeline constructor.
        """
        # NOTE(review): ``model`` and ``tokenizer`` are passed by keyword, so
        # extra positional ``*args`` may collide with them depending on the
        # parent signature -- confirm before relying on positional extras.
        super().__init__(
            model=AutoModelForSeq2SeqLM.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            *args,
            **kwargs
        )
        self.keyphrase_sep_token = keyphrase_sep_token

    def postprocess(self, model_outputs):
        """Turn the first generated text into a list of clean keyphrases.

        Args:
            model_outputs: Raw model outputs, handed unchanged to the parent
                ``postprocess``.

        Returns:
            The keyphrases of the first result, in generation order, with
            punctuation and surrounding whitespace removed and empty entries
            dropped. An empty list if the parent produced no results.
        """
        results = super().postprocess(model_outputs=model_outputs)
        if not results:
            # Previously ``[...][0]`` raised IndexError on an empty result.
            return []

        # Only the first result is returned, so only it needs processing.
        generated_text = results[0].get("generated_text")

        # Remove punctuation first and strip afterwards: stripping first
        # leaves stray spaces behind ("foo ." -> "foo ") and lets
        # whitespace-only segments through as empty keyphrases.
        cleaned = (
            segment.translate(self._PUNCTUATION_TABLE).strip()
            for segment in generated_text.split(self.keyphrase_sep_token)
        )
        return [keyphrase for keyphrase in cleaned if keyphrase]