Thomas De Decker commited on
Commit
8d04b0f
1 Parent(s): f2f4fc6

Fix extraction pipeline

Browse files
app.py CHANGED
@@ -12,7 +12,7 @@ from pipelines.keyphrase_generation_pipeline import KeyphraseGenerationPipeline
12
  @st.cache(allow_output_mutation=True, show_spinner=False)
13
  def load_pipeline(chosen_model):
14
  if "keyphrase-extraction" in chosen_model:
15
- return KeyphraseExtractionPipeline(chosen_model, truncation=True)
16
  elif "keyphrase-generation" in chosen_model:
17
  return KeyphraseGenerationPipeline(chosen_model, truncation=True)
18
 
 
12
  @st.cache(allow_output_mutation=True, show_spinner=False)
13
  def load_pipeline(chosen_model):
14
  if "keyphrase-extraction" in chosen_model:
15
+ return KeyphraseExtractionPipeline(chosen_model)
16
  elif "keyphrase-generation" in chosen_model:
17
  return KeyphraseGenerationPipeline(chosen_model, truncation=True)
18
 
pipelines/keyphrase_generation_pipeline.py CHANGED
@@ -11,7 +11,7 @@ class KeyphraseGenerationPipeline(Text2TextGenerationPipeline):
11
  def __init__(self, model, keyphrase_sep_token=";", *args, **kwargs):
12
  super().__init__(
13
  model=AutoModelForSeq2SeqLM.from_pretrained(model),
14
- tokenizer=AutoTokenizer.from_pretrained(model),
15
  *args,
16
  **kwargs
17
  )
 
11
  def __init__(self, model, keyphrase_sep_token=";", *args, **kwargs):
12
  super().__init__(
13
  model=AutoModelForSeq2SeqLM.from_pretrained(model),
14
+ tokenizer=AutoTokenizer.from_pretrained(model, truncation=True),
15
  *args,
16
  **kwargs
17
  )