Update Summarizer/Extractive.py

Summarizer/Extractive.py  +26 -0
@@ -1,4 +1,5 @@
 import nltk
+import torch
 from summarizer import Summarizer
 from sumy.nlp.tokenizers import Tokenizer
 from sumy.summarizers.lsa import LsaSummarizer
@@ -37,6 +38,31 @@ def summarize(file, model):
                                   skip_special_tokens=True,
                                   clean_up_tokenization_spaces=False)
         summary = summary[0]
+
+    elif model == "LEDBill":
+        tokenizer = AutoTokenizer.from_pretrained("d0r1h/LEDBill")
+        model = AutoModelForSeq2SeqLM.from_pretrained("d0r1h/LEDBill", return_dict_in_generate=True)
+
+        input_ids = tokenizer(doc, return_tensors="pt").input_ids
+        global_attention_mask = torch.zeros_like(input_ids)
+        global_attention_mask[:, 0] = 1
+
+        sequences = model.generate(input_ids, global_attention_mask=global_attention_mask).sequences
+        summary = tokenizer.batch_decode(sequences, skip_special_tokens=True)
+
+        summary = summary[0]
+
+    elif model == "ILC":
+        tokenizer = AutoTokenizer.from_pretrained("d0r1h/led-base-ilc")
+        model = AutoModelForSeq2SeqLM.from_pretrained("d0r1h/led-base-ilc", return_dict_in_generate=True)
+
+        input_ids = tokenizer(doc, return_tensors="pt").input_ids
+        global_attention_mask = torch.zeros_like(input_ids)
+        global_attention_mask[:, 0] = 1
+
+        sequences = model.generate(input_ids, global_attention_mask=global_attention_mask).sequences
+        summary = tokenizer.batch_decode(sequences, skip_special_tokens=True)
+
+        summary = summary[0]
 
     elif model == "TextRank":
         summary = extractive(LexRankSummarizer(), doc)
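The two new branches are identical except for the checkpoint name, and both follow the usual Longformer Encoder-Decoder (LED) generation recipe: tokenize the document, mark at least one token (conventionally the first) for global attention, generate, then decode the first sequence. A minimal sketch of that shared pattern factored into a helper is below; the helper name led_summarize is an assumption for illustration, not part of this commit:

# Sketch only: factors the duplicated LEDBill/ILC branches into one helper.
# The function name led_summarize is hypothetical, not from the commit.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def led_summarize(checkpoint, doc):
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, return_dict_in_generate=True)

    input_ids = tokenizer(doc, return_tensors="pt").input_ids

    # LED attends locally by default; giving the first token global
    # attention lets every position see it during encoding.
    global_attention_mask = torch.zeros_like(input_ids)
    global_attention_mask[:, 0] = 1

    sequences = model.generate(input_ids, global_attention_mask=global_attention_mask).sequences
    return tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]

# With such a helper the two branches would reduce to:
#   elif model == "LEDBill":
#       summary = led_summarize("d0r1h/LEDBill", doc)
#   elif model == "ILC":
#       summary = led_summarize("d0r1h/led-base-ilc", doc)

A side benefit of this shape is that the model-name parameter of summarize() is no longer rebound to the loaded model object inside the branch, which the committed code does (harmlessly, since the string is not read again afterwards).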