hellopahe committed on
Commit
6a0cb69
β€’
1 Parent(s): c215129

add custom siblings

Browse files
lex_rank.py CHANGED
@@ -1,9 +1,11 @@
 
 
1
  import numpy, nltk
2
  nltk.download('punkt')
3
 
4
 
5
  from harvesttext import HarvestText
6
- from lex_rank_util import degree_centrality_scores
7
  from sentence_transformers import SentenceTransformer, util
8
 
9
 
@@ -12,7 +14,7 @@ class LexRank(object):
12
  self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
13
  self.ht = HarvestText()
14
 
15
- def find_central(self, content: str, num=10):
16
  if self.contains_chinese(content):
17
  sentences = self.ht.cut_sentences(content)
18
  else:
@@ -33,7 +35,7 @@ class LexRank(object):
33
  for index in most_central_sentence_indices:
34
  if num < 0:
35
  break
36
- res.append(sentences[index])
37
  num -= 1
38
  return res
39
 
@@ -42,3 +44,5 @@ class LexRank(object):
42
  if '\u4e00' <= _char <= '\u9fa5':
43
  return True
44
  return False
 
 
 
1
+ import math
2
+
3
  import numpy, nltk
4
  nltk.download('punkt')
5
 
6
 
7
  from harvesttext import HarvestText
8
+ from lex_rank_util import degree_centrality_scores, find_siblings
9
  from sentence_transformers import SentenceTransformer, util
10
 
11
 
 
14
  self.model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
15
  self.ht = HarvestText()
16
 
17
+ def find_central(self, content: str, num=10, siblings=0):
18
  if self.contains_chinese(content):
19
  sentences = self.ht.cut_sentences(content)
20
  else:
 
35
  for index in most_central_sentence_indices:
36
  if num < 0:
37
  break
38
+ res.append(find_siblings(sentences, index, siblings)[1])
39
  num -= 1
40
  return res
41
 
 
44
  if '\u4e00' <= _char <= '\u9fa5':
45
  return True
46
  return False
47
+
48
+
lex_rank_L12.py CHANGED
@@ -3,7 +3,7 @@ nltk.download('punkt')
3
 
4
 
5
  from harvesttext import HarvestText
6
- from lex_rank_util import degree_centrality_scores
7
  from sentence_transformers import SentenceTransformer, util
8
 
9
 
@@ -12,7 +12,7 @@ class LexRankL12(object):
12
  self.model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
13
  self.ht = HarvestText()
14
 
15
- def find_central(self, content: str, num=10):
16
  if self.contains_chinese(content):
17
  sentences = self.ht.cut_sentences(content)
18
  else:
@@ -33,7 +33,7 @@ class LexRankL12(object):
33
  for index in most_central_sentence_indices:
34
  if num < 0:
35
  break
36
- res.append(sentences[index])
37
  num -= 1
38
  return res
39
 
 
3
 
4
 
5
  from harvesttext import HarvestText
6
+ from lex_rank_util import degree_centrality_scores, find_siblings
7
  from sentence_transformers import SentenceTransformer, util
8
 
9
 
 
12
  self.model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
13
  self.ht = HarvestText()
14
 
15
+ def find_central(self, content: str, num=10, siblings=0):
16
  if self.contains_chinese(content):
17
  sentences = self.ht.cut_sentences(content)
18
  else:
 
33
  for index in most_central_sentence_indices:
34
  if num < 0:
35
  break
36
+ res.append(find_siblings(sentences, index, siblings)[1])
37
  num -= 1
38
  return res
39
 
lex_rank_distiluse_v1.py β†’ lex_rank_text2vec_v1.py RENAMED
@@ -3,16 +3,16 @@ nltk.download('punkt')
3
 
4
 
5
  from harvesttext import HarvestText
6
- from lex_rank_util import degree_centrality_scores
7
  from sentence_transformers import SentenceTransformer, util
8
 
9
 
10
- class LexRankDistiluseV1(object):
11
  def __init__(self):
12
- self.model = SentenceTransformer('distiluse-base-multilingual-cased-v1')
13
  self.ht = HarvestText()
14
 
15
- def find_central(self, content: str, num=10):
16
  if self.contains_chinese(content):
17
  sentences = self.ht.cut_sentences(content)
18
  else:
@@ -33,7 +33,7 @@ class LexRankDistiluseV1(object):
33
  for index in most_central_sentence_indices:
34
  if num < 0:
35
  break
36
- res.append(sentences[index])
37
  num -= 1
38
  return res
39
 
 
3
 
4
 
5
  from harvesttext import HarvestText
6
+ from lex_rank_util import degree_centrality_scores, find_siblings
7
  from sentence_transformers import SentenceTransformer, util
8
 
9
 
10
+ class LexRankText2VecV1(object):
11
  def __init__(self):
12
+ self.model = SentenceTransformer('shibing624/text2vec-base-chinese-paraphrase')
13
  self.ht = HarvestText()
14
 
15
+ def find_central(self, content: str, num=10, siblings=0):
16
  if self.contains_chinese(content):
17
  sentences = self.ht.cut_sentences(content)
18
  else:
 
33
  for index in most_central_sentence_indices:
34
  if num < 0:
35
  break
36
+ res.append(find_siblings(sentences, index, siblings)[1])
37
  num -= 1
38
  return res
39
 
lex_rank_util.py CHANGED
@@ -6,7 +6,7 @@ Source: https://github.com/crabcamp/lexrank/tree/dev
6
  import numpy as np
7
  from scipy.sparse.csgraph import connected_components
8
  from scipy.special import softmax
9
- import logging
10
 
11
  logger = logging.getLogger(__name__)
12
 
@@ -121,4 +121,12 @@ def stationary_distribution(
121
  if normalized:
122
  distribution /= n_1
123
 
124
- return distribution
 
 
 
 
 
 
 
 
 
6
  import numpy as np
7
  from scipy.sparse.csgraph import connected_components
8
  from scipy.special import softmax
9
+ import logging, math
10
 
11
  logger = logging.getLogger(__name__)
12
 
 
121
  if normalized:
122
  distribution /= n_1
123
 
124
+ return distribution
125
+
126
+
127
def find_siblings(sentences: list[str], idx: int, siblings: int) -> tuple[int, str]:
    """Return the sentence at *idx* joined with up to *siblings* neighbours on each side.

    Args:
        sentences: ordered list of sentences the text was split into.
        idx: index of the central sentence within ``sentences``.
        siblings: number of neighbouring sentences to include on each side.

    Returns:
        ``(0, joined_text)`` on success, or ``(-1, error_message)`` when
        ``siblings`` is too large relative to the number of sentences.
    """
    # Reject windows of half the document or more: the "sibling" context
    # would then cover (nearly) the whole text and defeat extraction.
    if siblings >= math.ceil(len(sentences) / 2):
        return -1, "siblings too large, try some value smaller."
    # Clamp the window to the list bounds so sentences near either edge
    # still yield a valid (possibly asymmetric) slice.
    head = max(idx - siblings, 0)
    tail = min(idx + siblings + 1, len(sentences))
    return 0, "".join(sentences[head:tail])