d0r1h committed on
Commit
2c55a80
1 Parent(s): 8df0435

Update Summarizer/Extractive.py

Browse files
Files changed (1) hide show
  1. Summarizer/Extractive.py +26 -0
Summarizer/Extractive.py CHANGED
@@ -1,4 +1,5 @@
1
  import nltk
 
2
  from summarizer import Summarizer
3
  from sumy.nlp.tokenizers import Tokenizer
4
  from sumy.summarizers.lsa import LsaSummarizer
@@ -37,6 +38,31 @@ def summarize(file, model):
37
  skip_special_tokens=True,
38
  clean_up_tokenization_spaces=False)
39
  summary = summary[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  elif model == "TextRank":
42
  summary = extractive(LexRankSummarizer(), doc)
 
1
  import nltk
2
+ import torch
3
  from summarizer import Summarizer
4
  from sumy.nlp.tokenizers import Tokenizer
5
  from sumy.summarizers.lsa import LsaSummarizer
 
38
  skip_special_tokens=True,
39
  clean_up_tokenization_spaces=False)
40
  summary = summary[0]
41
+ elif model == "LEDBill":
42
+ tokenizer = AutoTokenizer.from_pretrained("d0r1h/LEDBill")
43
+ model = AutoModelForSeq2SeqLM.from_pretrained("d0r1h/LEDBill", return_dict_in_generate=True)
44
+
45
+ input_ids = tokenizer(doc, return_tensors="pt").input_ids
46
+ global_attention_mask = torch.zeros_like(input_ids)
47
+ global_attention_mask[:, 0] = 1
48
+
49
+ sequences = model.generate(input_ids, global_attention_mask=global_attention_mask).sequences
50
+ summary = tokenizer.batch_decode(sequences, skip_special_tokens=True)
51
+
52
+ summary = summary[0]
53
+
54
+ elif model == "ILC":
55
+ tokenizer = AutoTokenizer.from_pretrained("d0r1h/led-base-ilc")
56
+ model = AutoModelForSeq2SeqLM.from_pretrained("d0r1h/led-base-ilc", return_dict_in_generate=True)
57
+
58
+ input_ids = tokenizer(doc, return_tensors="pt").input_ids
59
+ global_attention_mask = torch.zeros_like(input_ids)
60
+ global_attention_mask[:, 0] = 1
61
+
62
+ sequences = model.generate(input_ids, global_attention_mask=global_attention_mask).sequences
63
+ summary = tokenizer.batch_decode(sequences, skip_special_tokens=True)
64
+
65
+ summary = summary[0]
66
 
67
  elif model == "TextRank":
68
  summary = extractive(LexRankSummarizer(), doc)