d0r1h committed on
Commit
e76e540
1 Parent(s): 2c55a80

Update Summarizer/Extractive.py

Browse files
Files changed (1) hide show
  1. Summarizer/Extractive.py +14 -0
Summarizer/Extractive.py CHANGED
@@ -63,6 +63,20 @@ def summarize(file, model):
63
  summary = tokenizer.batch_decode(sequences, skip_special_tokens=True)
64
 
65
  summary = summary[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  elif model == "TextRank":
68
  summary = extractive(LexRankSummarizer(), doc)
 
63
  summary = tokenizer.batch_decode(sequences, skip_special_tokens=True)
64
 
65
  summary = summary[0]
66
+ elif model == "Distill":
67
+ checkpoint = "sshleifer/distill-pegasus-cnn-16-4"
68
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
69
+ model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
70
+ inputs = tokenizer(doc,
71
+ max_length=1024,
72
+ truncation=True,
73
+ return_tensors="pt")
74
+
75
+ summary_ids = model.generate(inputs["input_ids"])
76
+ summary = tokenizer.batch_decode(summary_ids,
77
+ skip_special_tokens=True,
78
+ clean_up_tokenization_spaces=False)
79
+ summary = summary[0]
80
 
81
  elif model == "TextRank":
82
  summary = extractive(LexRankSummarizer(), doc)