hellopahe commited on
Commit
90f83ff
1 Parent(s): 6a62e87
Files changed (1) hide show
  1. app.py +13 -10
app.py CHANGED
@@ -1,8 +1,12 @@
 
1
  import torch
2
  import gradio as gr
3
 
4
  from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline
5
  from article_extractor.tokenizers_pegasus import PegasusTokenizer
 
 
 
6
 
7
 
8
  class SummaryExtractor(object):
@@ -17,18 +21,22 @@ class SummaryExtractor(object):
17
 
18
 
19
  t_randeng = SummaryExtractor()
 
20
 
21
 
22
  def randeng_extract(content):
23
  return t_randeng.extract(content)
24
 
25
- def similarity_check(query: str, doc: str):
26
- doc_list = doc.split("\n")
27
 
28
- return "similarity result"
 
 
 
 
 
29
 
30
  with gr.Blocks() as app:
31
- gr.Markdown("从下面的标签选择不同的摘要模型, 在左侧输入原文")
32
  # with gr.Tab("CamelBell-Chinese-LoRA"):
33
  # text_input = gr.Textbox()
34
  # text_output = gr.Textbox()
@@ -37,11 +45,6 @@ with gr.Blocks() as app:
37
  text_input_1 = gr.Textbox()
38
  text_output_1 = gr.Textbox()
39
  text_button_1 = gr.Button("生成摘要")
40
- # with gr.Tab("Flip Image"):
41
- # with gr.Row():
42
- # image_input = gr.Image()
43
- # image_output = gr.Image()
44
- # image_button = gr.Button("Flip")
45
  with gr.Tab("相似度检测"):
46
  with gr.Row():
47
  text_input_query = gr.Textbox()
@@ -51,6 +54,6 @@ with gr.Blocks() as app:
51
 
52
  # text_button.click(tuoling_extract, inputs=text_input, outputs=text_output)
53
  text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
54
- text_button_similarity.click(similarity_check, inputs=text_input_query, outputs=text_input_doc)
55
 
56
  app.launch()
 
1
+ import numpy
2
  import torch
3
  import gradio as gr
4
 
5
  from transformers import PegasusForConditionalGeneration, Text2TextGenerationPipeline
6
  from article_extractor.tokenizers_pegasus import PegasusTokenizer
7
+ from embed import Embed
8
+
9
+ import tensorflow as tf
10
 
11
 
12
  class SummaryExtractor(object):
 
21
 
22
 
23
  t_randeng = SummaryExtractor()
24
+ embedder = Embed()
25
 
26
 
27
  def randeng_extract(content):
28
  return t_randeng.extract(content)
29
 
 
 
30
 
31
+ def similarity_check(inputs: list):
32
+ doc_list = inputs[1].split("\n")
33
+ doc_list.append(inputs[0])
34
+ embedding_list = embedder.encode(doc_list)
35
+ scores = (embedding_list[-1] @ tf.transpose(embedding_list[:-1]))[0].numpy().tolist()
36
+ return numpy.array2string(scores, separator=',')
37
 
38
  with gr.Blocks() as app:
39
+ gr.Markdown("从下面的标签选择测试模块 [摘要生成,相似度检测]")
40
  # with gr.Tab("CamelBell-Chinese-LoRA"):
41
  # text_input = gr.Textbox()
42
  # text_output = gr.Textbox()
 
45
  text_input_1 = gr.Textbox()
46
  text_output_1 = gr.Textbox()
47
  text_button_1 = gr.Button("生成摘要")
 
 
 
 
 
48
  with gr.Tab("相似度检测"):
49
  with gr.Row():
50
  text_input_query = gr.Textbox()
 
54
 
55
  # text_button.click(tuoling_extract, inputs=text_input, outputs=text_output)
56
  text_button_1.click(randeng_extract, inputs=text_input_1, outputs=text_output_1)
57
+ text_button_similarity.click(similarity_check, inputs=[text_input_query, text_input_doc], outputs=text_output_similarity)
58
 
59
  app.launch()