dev
- abstruct.py +71 -0
- classification.py +83 -0
- requirements.txt +109 -0
- run.py +1 -0
- util.py +75 -0
abstruct.py
ADDED
@@ -0,0 +1,71 @@
# Import the required libraries
import json
import paddlenlp
import gensim
import sklearn
from collections import Counter
from gensim import corpora, models, similarities
import numpy as np
import matplotlib.pyplot as plt


def build_corpus(sentences):
    # Use the pretrained vocabulary shipped with paddlenlp
    vocab = paddlenlp.transformers.BertTokenizer.from_pretrained('bert-base-chinese').vocab

    # Create the tokenizer
    tokenizer = paddlenlp.data.JiebaTokenizer(vocab)
    # Tokenize each sentence and drop stopwords, producing a list of token lists
    stopwords = [""]
    words_list = []
    for sentence in sentences:
        words = [word for word in tokenizer.cut(sentence) if word not in stopwords]
        words_list.append(words)
    # print(words_list)
    # Flatten the nested list into a single token list (not used further below)
    words = [word for sentence in words_list for word in sentence]

    # Build the gensim dictionary and bag-of-words corpus
    dictionary = corpora.Dictionary(words_list)
    corpus = [dictionary.doc2bow(text) for text in words_list]

    return corpus, dictionary, words_list


def lda(words_list, sentences, corpus, dictionary, num):
    lda = gensim.models.ldamodel.LdaModel(corpus=corpus, id2word=dictionary, num_topics=num)

    topics = lda.print_topics(num_topics=num, num_words=10)

    # Based on keyword overlap, pick the sentence that best represents each topic as its central sentence
    central_sentences = []
    for topic in topics:
        topic_id, topic_words = topic
        # print_topics returns strings like '0.030*"word" + ...'; extract the bare keywords
        topic_words = [word.split("*")[1].strip(' "') for word in topic_words.split("+")]
        max_score = 0
        candidates = []  # candidate central sentences
        for sentence, words in zip(sentences, words_list):
            score = 0
            for word in words:
                if word in topic_words:
                    score += 1
            if score > max_score:
                max_score = score
                candidates = [sentence]  # a higher score resets the candidate list
            elif score == max_score:
                candidates.append(sentence)  # equal scores are kept as candidates
        for candidate in candidates:  # walk the candidates in order
            if candidate not in central_sentences:  # skip sentences already chosen for another topic
                central_sentences.append(candidate)
                break

    return central_sentences


def abstruct_main(sentences, num):
    corpus, dictionary, words_list = build_corpus(sentences)
    central_sentences = lda(words_list, sentences, corpus, dictionary, num)
    return central_sentences
classification.py
ADDED
@@ -0,0 +1,83 @@
import gensim
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
import torch


def classify_by_topic(articles, central_topics):

    # Compute the similarity between every article sentence and every central topic, returning a matrix
    def compute_similarity(articles, central_topics):

        model = AutoModel.from_pretrained("distilbert-base-multilingual-cased")
        tokenizer = AutoTokenizer.from_pretrained(
            "distilbert-base-multilingual-cased")

        # Convert a single sentence into a vector
        def sentence_to_vector(sentence, context):
            # Prepend/append the context and repeat the target sentence to weight it,
            # then tokenize and add special tokens
            sentence = context[0]+context[1]+sentence*4+context[2]+context[3]
            tokens = tokenizer.encode_plus(
                sentence, add_special_tokens=True, return_tensors="pt")
            # Get the hidden-state vector of every token
            outputs = model(**tokens)
            hidden_states = outputs.last_hidden_state
            # Use the mean of the token vectors as the sentence vector
            vector = np.squeeze(torch.mean(
                hidden_states, dim=1).detach().numpy())  # a 1 x d tensor
            return vector

        # Get the surrounding context of a sentence
        def get_context(sentences, index):
            if index == 0:
                prev_sentence = ""
                pprev_sentence = ""
            elif index == 1:
                prev_sentence = sentences[index-1]
                pprev_sentence = ""
            else:
                prev_sentence = sentences[index-1]
                pprev_sentence = sentences[index-2]
            if index == len(sentences) - 1:
                next_sentence = ""
                nnext_sentence = ""
            elif index == len(sentences) - 2:
                next_sentence = sentences[index+1]
                nnext_sentence = ""
            else:
                next_sentence = sentences[index+1]
                nnext_sentence = sentences[index+2]
            return (pprev_sentence, prev_sentence, next_sentence, nnext_sentence)

        # Convert every article sentence and every central sentence into a vector
        doc_vectors = [sentence_to_vector(sentence, get_context(
            articles, i)) for i, sentence in enumerate(articles)]
        topic_vectors = [sentence_to_vector(sentence, get_context(
            central_topics, i)) for i, sentence in enumerate(central_topics)]
        # Cosine-similarity matrix between article sentences and central sentences
        cos_sim_matrix = cosine_similarity(doc_vectors, topic_vectors)

        # print(cos_sim_matrix)
        return cos_sim_matrix

    # Group the articles by their most similar topic, returning a list
    def group_by_topic(articles, central_topics, similarity_matrix):
        group = []
        original_articles = articles.copy()  # keep a copy of the original article list
        for article, similarity in zip(original_articles, similarity_matrix):
            max_similarity = max(similarity)  # highest similarity value
            max_index = similarity.tolist().index(max_similarity)  # index of the most similar topic
            # print(max_similarity, max_index)
            group.append((article, central_topics[max_index]))

        return group

    # Run the classification
    similarity_matrix = compute_similarity(articles, central_topics)
    groups = group_by_topic(articles, central_topics, similarity_matrix)

    # Return the grouped list
    return groups
requirements.txt
ADDED
@@ -0,0 +1,109 @@
aiohttp==3.8.4
aiosignal==1.3.1
anyio==3.7.0
astor==0.8.1
async-timeout==4.0.2
attrs==23.1.0
Babel==2.12.1
backoff==2.2.1
bce-python-sdk==0.8.83
blinker==1.6.2
certifi==2023.5.7
charset-normalizer==3.1.0
click==8.1.3
cmake==3.26.3
colorama==0.4.6
colorlog==6.7.0
contourpy==1.0.7
cycler==0.11.0
datasets==2.12.0
decorator==5.1.1
dill==0.3.4
exceptiongroup==1.1.1
fastapi==0.95.2
filelock==3.12.0
Flask==2.3.2
Flask-Babel==2.0.0
fonttools==4.39.4
frozenlist==1.3.3
fsspec==2023.5.0
future==0.18.3
gensim==4.3.1
h11==0.14.0
huggingface-hub==0.14.1
idna==3.4
importlib-metadata==6.6.0
importlib-resources==5.12.0
itsdangerous==2.1.2
jieba==0.42.1
Jinja2==3.1.2
joblib==1.2.0
kiwisolver==1.4.4
lit==16.0.5
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.12.2
networkx==3.1
numpy==1.24.3
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
opt-einsum==3.3.0
packaging==23.1
paddle-bfloat==0.1.7
paddle2onnx==1.0.6
paddlefsl==1.1.0
paddlenlp==2.5.2
paddlepaddle==2.4.2
pandas==2.0.2
Pillow==9.5.0
protobuf==3.20.0
pyarrow==12.0.0
pycryptodome==3.18.0
pydantic==1.10.8
Pygments==2.15.1
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2023.3
PyYAML==6.0
regex==2023.5.5
requests==2.31.0
responses==0.18.0
rich==13.4.1
scikit-learn==1.2.2
scipy==1.10.1
sentencepiece==0.1.99
seqeval==1.2.2
six==1.16.0
smart-open==6.3.0
sniffio==1.3.0
starlette==0.27.0
sympy==1.12
threadpoolctl==3.1.0
tokenizers==0.13.3
torch==2.0.1
tqdm==4.65.0
transformers==4.29.2
triton==2.0.0
typer==0.9.0
typing-extensions==4.6.2
tzdata==2023.3
urllib3==2.0.2
uvicorn==0.22.0
visualdl==2.4.2
Werkzeug==2.3.4
xxhash==3.2.0
yarl==1.9.2
zipp==3.15.0
run.py
CHANGED
@@ -40,3 +40,4 @@ ans = util.generation(groups, max_length)
 # {(main_sentence,(Ai_abstruct,paragraph))}
 for i in ans.items():
     print(i)
+``
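Only the tail of run.py is visible in this hunk, so the following is an illustrative sketch of how the modules appear to be wired together, inferred from the hunk context `ans = util.generation(groups, max_length)`; the variable values and the exact preprocessing order are assumptions, not the actual file.

# Assumed end-to-end flow (sketch, not the real run.py).
import util
from abstruct import abstruct_main
from classification import classify_by_topic

text = "..."        # raw input document (placeholder)
num_topics = 3      # assumed LDA topic count
max_length = 200    # summary length passed to util.generation

sentences = [s for s in util.seg(text) if s]
central_sentences = abstruct_main(sentences, num_topics)
pairs = classify_by_topic(sentences, central_sentences)
groups = util.article_to_group(pairs, central_sentences)
ans = util.generation(groups, max_length)

# {(main_sentence,(Ai_abstruct,paragraph))}
for i in ans.items():
    print(i)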
util.py
ADDED
@@ -0,0 +1,75 @@
import json
import jieba
import re
import requests
import backoff


# POST with exponential-backoff retries on request errors
@backoff.on_exception(backoff.expo, requests.exceptions.RequestException)
def post_url(url, headers, payload):
    response = requests.request("POST", url, headers=headers, data=payload)
    return response


# Split the text into sentences at sentence-ending punctuation
def seg(text):
    sentences = re.split(r'(?<=[。!?])\s*', text)
    return sentences


# Strip dates, times, URLs and other noise, keeping alphabetic words
def clean_text(text):
    text = text.replace('\n', " ")
    text = re.sub(r"-", " ", text)
    text = re.sub(r"\d+/\d+/\d+", "", text)  # dates
    text = re.sub(r"[0-2]?[0-9]:[0-6][0-9]", "", text)  # times
    text = re.sub(
        r"/[a-zA-Z]*[:\//\]*[A-Za-z0-9\-_]+\.+[A-Za-z0-9\.\/%&=\?\-_]+/i", "", text)  # URLs
    pure_text = ''
    for letter in text:
        if letter.isalpha() or letter == ' ':
            pure_text += letter

    text = ' '.join(word for word in pure_text.split() if len(word) > 1)
    return text


# Merge the sentences assigned to the same central topic into one paragraph
def article_to_group(groups, topics):
    para = {}
    for i in groups:
        if not i[1] in para:
            para[i[1]] = i[0]
        else:
            para[i[1]] = para[i[1]] + i[0]
    return para


# Summarize each topic's paragraph with Baidu's news_summary API
def generation(para, max_length):
    API_KEY = "IZt1uK9PAI0LiqleqT0cE30b"
    SECRET_KEY = "Xv5kHB8eyhNuI1B1G7fRgm2SIPdlxGxs"

    def get_access_token():
        url = "https://aip.baidubce.com/oauth/2.0/token"
        params = {"grant_type": "client_credentials",
                  "client_id": API_KEY, "client_secret": SECRET_KEY}
        return str(requests.post(url, params=params).json().get("access_token"))

    url = "https://aip.baidubce.com/rpc/2.0/nlp/v1/news_summary?charset=UTF-8&access_token=" + get_access_token()
    topic = {}

    for i, (j, k) in enumerate(para.items()):
        input_text = k
        # print(k)
        payload = json.dumps({
            "content": k,
            "max_summary_len": max_length
        })
        headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }

        response = post_url(url, headers, payload)
        text_dict = json.loads(response.text)
        # print(text_dict)
        topic[j] = (text_dict['summary'], k)
    return topic