Thomas De Decker committed on
Commit
f2f4fc6
•
1 Parent(s): 099b1c5

Add max input length

Browse files
app.py CHANGED
@@ -12,9 +12,9 @@ from pipelines.keyphrase_generation_pipeline import KeyphraseGenerationPipeline
12
  @st.cache(allow_output_mutation=True, show_spinner=False)
13
  def load_pipeline(chosen_model):
14
  if "keyphrase-extraction" in chosen_model:
15
- return KeyphraseExtractionPipeline(chosen_model)
16
  elif "keyphrase-generation" in chosen_model:
17
- return KeyphraseGenerationPipeline(chosen_model)
18
 
19
 
20
  def extract_keyphrases():
@@ -159,7 +159,12 @@ with st.form("keyphrase-extraction-form"):
159
  )
160
 
161
  st.session_state.input_text = (
162
- st.text_area("✍ Input", st.session_state.config.get("example_text"), height=250)
 
 
 
 
 
163
  .replace("\n", " ")
164
  .strip()
165
  )
 
12
  @st.cache(allow_output_mutation=True, show_spinner=False)
13
  def load_pipeline(chosen_model):
14
  if "keyphrase-extraction" in chosen_model:
15
+ return KeyphraseExtractionPipeline(chosen_model, truncation=True)
16
  elif "keyphrase-generation" in chosen_model:
17
+ return KeyphraseGenerationPipeline(chosen_model, truncation=True)
18
 
19
 
20
  def extract_keyphrases():
 
159
  )
160
 
161
  st.session_state.input_text = (
162
+ st.text_area(
163
+ "✍ Input",
164
+ st.session_state.config.get("example_text"),
165
+ height=250,
166
+ max_chars=2500,
167
+ )
168
  .replace("\n", " ")
169
  .strip()
170
  )
pipelines/keyphrase_extraction_pipeline.py CHANGED
@@ -11,9 +11,7 @@ class KeyphraseExtractionPipeline(TokenClassificationPipeline):
11
  def __init__(self, model, *args, **kwargs):
12
  super().__init__(
13
  model=AutoModelForTokenClassification.from_pretrained(model),
14
- tokenizer=AutoTokenizer.from_pretrained(
15
- model, truncate=True
16
- ),
17
  *args,
18
  **kwargs
19
  )
 
11
  def __init__(self, model, *args, **kwargs):
12
  super().__init__(
13
  model=AutoModelForTokenClassification.from_pretrained(model),
14
+ tokenizer=AutoTokenizer.from_pretrained(model),
 
 
15
  *args,
16
  **kwargs
17
  )
pipelines/keyphrase_generation_pipeline.py CHANGED
@@ -1,14 +1,17 @@
1
  import string
2
 
3
- from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
4
- Text2TextGenerationPipeline)
 
 
 
5
 
6
 
7
  class KeyphraseGenerationPipeline(Text2TextGenerationPipeline):
8
  def __init__(self, model, keyphrase_sep_token=";", *args, **kwargs):
9
  super().__init__(
10
  model=AutoModelForSeq2SeqLM.from_pretrained(model),
11
- tokenizer=AutoTokenizer.from_pretrained(model, truncate=True),
12
  *args,
13
  **kwargs
14
  )
 
1
  import string
2
 
3
+ from transformers import (
4
+ AutoModelForSeq2SeqLM,
5
+ AutoTokenizer,
6
+ Text2TextGenerationPipeline,
7
+ )
8
 
9
 
10
  class KeyphraseGenerationPipeline(Text2TextGenerationPipeline):
11
  def __init__(self, model, keyphrase_sep_token=";", *args, **kwargs):
12
  super().__init__(
13
  model=AutoModelForSeq2SeqLM.from_pretrained(model),
14
+ tokenizer=AutoTokenizer.from_pretrained(model),
15
  *args,
16
  **kwargs
17
  )