Arjav committed
Commit: 9de349e
Parent: 49b1ffe

Updated app to use fine tuned longformer

Files changed (1):
  app.py +9 -7
app.py CHANGED
@@ -1,21 +1,23 @@
 import gradio as gr
 import torch
-from transformers import PegasusTokenizer, PegasusForConditionalGeneration
+from transformers import AutoTokenizer, LEDForConditionalGeneration


 def summarize(Terms):
-    tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-billsum')
-    model = PegasusForConditionalGeneration.from_pretrained(
-        "arjav/TOS-Pegasus")
+    tokenizer = AutoTokenizer.from_pretrained('allenai/led-base-16384')
+    model = LEDForConditionalGeneration.from_pretrained("Arjav/TOS-Longformer")
+
     input_tokenized = tokenizer.encode(
-        Terms, return_tensors='pt', max_length=1024, truncation=True)
+        Terms, return_tensors='pt', max_length=8192, truncation=True)
+
     summary_ids = model.generate(input_tokenized,
                                  num_beams=9,
                                  no_repeat_ngram_size=3,
                                  length_penalty=2.0,
-                                 min_length= 150,
-                                 max_length= 200,
+                                 min_length= 200,
+                                 max_length= 400,
                                  early_stopping=True)
+
     summary = [tokenizer.decode(g, skip_special_tokens=True,
                                 clean_up_tokenization_spaces=False) for g in summary_ids][0]

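For reference, a minimal sketch of app.py after this commit. The summarize body matches the hunk above; the return statement and the Gradio wiring are assumptions, since they fall outside the lines shown in the diff.

import gradio as gr
import torch
from transformers import AutoTokenizer, LEDForConditionalGeneration


def summarize(Terms):
    # Tokenizer comes from the base LED checkpoint; weights from the fine-tuned model.
    tokenizer = AutoTokenizer.from_pretrained('allenai/led-base-16384')
    model = LEDForConditionalGeneration.from_pretrained("Arjav/TOS-Longformer")

    # LED handles much longer inputs than Pegasus, hence the 8192-token cap
    # in place of the previous 1024.
    input_tokenized = tokenizer.encode(
        Terms, return_tensors='pt', max_length=8192, truncation=True)

    summary_ids = model.generate(input_tokenized,
                                 num_beams=9,
                                 no_repeat_ngram_size=3,
                                 length_penalty=2.0,
                                 min_length=200,
                                 max_length=400,
                                 early_stopping=True)

    summary = [tokenizer.decode(g, skip_special_tokens=True,
                                clean_up_tokenization_spaces=False) for g in summary_ids][0]
    return summary  # assumed: not visible in the hunk above


# Assumed Gradio wiring; the actual interface definition is outside the diff shown.
gr.Interface(fn=summarize, inputs="textbox", outputs="textbox").launch()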