Mbonea commited on
Commit
4e0c974
1 Parent(s): f337592

summarization improved

Browse files
App/Chat/ChatRoutes.py CHANGED
@@ -5,7 +5,8 @@ from App.Users.Schemas import UserSchema
5
  from App.Transcription.Model import Transcriptions
6
  from App.Transcription.Schemas import *
7
  from App import bot
8
- from .utils.PalmAPI import generate_summary,summarization
 
9
  import aiohttp
10
  import os
11
 
@@ -39,10 +40,6 @@ async def generate_message( task_id: str,
39
  text =''
40
  for item in result.content:
41
  text+=item['text']
42
- docs=generate_summary(text)
43
- summaries =[]
44
- for doc in docs:
45
- summary=await summarization(doc.page_content)
46
- summaries.append(summary)
47
- return summaries
48
 
 
5
  from App.Transcription.Model import Transcriptions
6
  from App.Transcription.Schemas import *
7
  from App import bot
8
+ from .utils.Summarize import Summarizer
9
+
10
  import aiohttp
11
  import os
12
 
 
40
  text =''
41
  for item in result.content:
42
  text+=item['text']
43
+ summary=await Summarizer(text)
44
+ return [summary]
 
 
 
 
45
 
App/Chat/utils/{PalmAPI.py → Dev/PalmAPI.py} RENAMED
@@ -1,15 +1,52 @@
1
  import aiohttp
2
  import asyncio
3
  import google.generativeai as palm
 
 
 
 
4
  import os
5
  PALM_API = ""
6
  API_KEY=os.environ.get("PALM_API",PALM_API)
7
  palm.configure(api_key=API_KEY)
8
 
9
-
10
-
11
- from langchain.text_splitter import RecursiveCharacterTextSplitter
12
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def count_tokens(text):
15
  return palm.count_message_tokens(prompt=text)['token_count']
@@ -108,5 +145,5 @@ Yo, Mabu, you really the only independent artist putting up numbers right now, b
108
 
109
 
110
 
111
- if __name__ == '__main__':
112
- asyncio.run(main=main())
 
1
  import aiohttp
2
  import asyncio
3
  import google.generativeai as palm
4
+ from langchain.llms import GooglePalm
5
+ from langchain.chains.summarize import load_summarize_chain
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ from langchain import PromptTemplate
8
  import os
9
  PALM_API = ""
10
  API_KEY=os.environ.get("PALM_API",PALM_API)
11
  palm.configure(api_key=API_KEY)
12
 
13
# LLM configured with every harm-category threshold at 4 so transcript
# content is not blocked by default filters.
# NOTE(review): threshold 4 appears intended as "most permissive" — confirm
# against the PaLM safety-settings documentation.
llm = GooglePalm(google_api_key=API_KEY, safety_settings=[
    {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 4},
    {"category": "HARM_CATEGORY_TOXICITY", "threshold": 4},
    {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 4},
    {"category": "HARM_CATEGORY_SEXUAL", "threshold": 4},
    {"category": "HARM_CATEGORY_MEDICAL", "threshold": 4},
    {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 4},
])

# ~10k-character chunks with 500-character overlap between chunks.
text_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n"], chunk_size=10000, chunk_overlap=500)

# Sample transcript used to exercise the map-reduce summarization chain.
essay= ''' TFC Mamma Ron Subway Galito Urban Heart Kootie Java Square In this video, I'm going to try every single fast food chain in Irobi Kenya and I'm going to rate them on a scale of terrible, bad, mid, good and for the incredible ones, go zest! I've broken them up into categories, so pizza category, burger category, chicken, general fast food and breakfast category and I'm starting with Pizza Hut To keep this fair across all restaurants, I'm ordering the cheapest possible meal on the menu or as close to my budget of 500 Kenya shillings and for that price in Pizza Hut Okay, so this is the mine meat lovers pizza This is going to be my first tasting of Pizza Hut in Irobi Kenya I haven't washed my hands, no one has to know that Okay, I could already feel how chunky this pizza is Maybe dip that in this barbecue sauce Mmm Okay, '''
docs = text_splitter.create_documents([essay])
# print(docs[0].page_content)

# Per-chunk ("map") prompt.
map_prompt = """
Write a concise summary of the following:
"{text}"
CONCISE SUMMARY:
"""
# Final ("combine"/reduce) prompt.
combine_prompt = """
Write a concise summary of the following text delimited by triple backquotes.
Return your response in bullet points which covers the key points of the text.
```{text}```
BULLET POINT SUMMARY:
"""
combine_prompt_template = PromptTemplate(template=combine_prompt, input_variables=["text"])
map_prompt_template = PromptTemplate(template=map_prompt, input_variables=["text"])

# BUG FIX: the original built a prompt-less map_reduce summary_chain here and
# immediately shadowed it with the one below — the dead first chain is removed.
summary_chain = load_summarize_chain(llm=llm,
                                     chain_type='map_reduce',
                                     map_prompt=map_prompt_template,
                                     combine_prompt=combine_prompt_template,
                                     verbose=True
                                     )
output = summary_chain.run(docs)
print(output)
50
 
51
def count_tokens(text):
    """Return how many PaLM tokens *text* occupies when counted as a prompt."""
    token_info = palm.count_message_tokens(prompt=text)
    return token_info['token_count']
 
145
 
146
 
147
 
148
+ # if __name__ == '__main__':
149
+ # asyncio.run(main=main())
App/Chat/utils/RAG.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import aiohttp
import asyncio, pprint
import google.generativeai as palm
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import GooglePalm
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain import PromptTemplate
import os

PALM_API = ''
API_KEY = os.environ.get("PALM_API", PALM_API)
palm.configure(api_key=API_KEY)


def count_tokens(text):
    """Return the PaLM token count of *text* when counted as a prompt."""
    return palm.count_message_tokens(prompt=text)['token_count']


# LLM with every harm-category threshold at 4.
# NOTE(review): threshold 4 appears intended as "most permissive" — confirm
# against the PaLM safety-settings documentation.
llm = GooglePalm(
    google_api_key=API_KEY, **{"safety_settings": [
        {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 4},
        {"category": "HARM_CATEGORY_TOXICITY", "threshold": 4},
        {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 4},
        {"category": "HARM_CATEGORY_SEXUAL", "threshold": 4},
        {"category": "HARM_CATEGORY_MEDICAL", "threshold": 4},
        {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 4},
    ]})

# Large (~40k-char) chunks with 500-char overlap so chunk boundaries do not
# cut sentences mid-thought; also splits on "." as a last resort.
text_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n", "."], chunk_size=40_000, chunk_overlap=500)

# BUG FIX: the demo below previously ran at import time, so merely importing
# this module tried to read ./sample.txt (which may not exist) and printed to
# stdout. It is now guarded so it only runs when executed as a script.
if __name__ == "__main__":
    with open('./sample.txt', 'r') as file:
        essay = file.read()

    docs = text_splitter.create_documents([essay])
    for doc in docs:
        print(count_tokens(doc.page_content))
App/Chat/utils/Summarize.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import aiohttp
2
+ import asyncio,pprint
3
+ import google.generativeai as palm
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain import PromptTemplate
6
+ import os
7
+ PALM_API = ''
8
+ API_KEY=os.environ.get("PALM_API",PALM_API)
9
+ palm.configure(api_key=API_KEY)
10
+
11
+
12
+
13
+
14
+ text_splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n","."], chunk_size=40_000, chunk_overlap=500)
15
+
16
+
17
# Prompt for the "map" step: each chunk is summarized independently.
# BUG FIX: the answer label previously read "CONCISE SUMMARY:", directly
# contradicting the "verbose summary" instruction above it; the label now
# matches the instruction.
map_prompt = """
Write a verbose summary like a masters student of the following:
"{text}"
VERBOSE SUMMARY:
"""


# Prompt for the "combine"/reduce step: merges the per-chunk summaries (or
# summarizes a single chunk directly) into one final paragraph.
combine_prompt = """
Write a concise summary of the following text delimited by triple backquotes.
Return your response in a detailed verbose paragraph which covers the text. Make it as insightful to the reader as possible, write like a masters student.

```{text}```

SUMMARY:
"""
32
def count_tokens(text):
    """Return the number of PaLM tokens *text* occupies as a prompt."""
    token_info = palm.count_message_tokens(prompt=text)
    return token_info['token_count']
34
+
35
+
36
async def PalmTextModel(text, candidates=1):
    """Generate text from the PaLM ``text-bison-001`` REST endpoint.

    Parameters:
        text: The fully formatted prompt to send to the model.
        candidates: Number of completions to request. With the default of 1
            a single output string is returned; with more than 1 a list of
            output strings is returned.

    Returns:
        str | list[str]: The model output(s).

    Raises:
        RuntimeError: If the API responds with a non-200 status.
            BUG FIX: the original only printed the error and implicitly
            returned None, which later crashed ``" ".join(...)`` in
            Summarizer with an opaque TypeError.
    """
    url = (
        "https://generativelanguage.googleapis.com/v1beta2/models/"
        f"text-bison-001:generateText?key={API_KEY}"
    )

    headers = {
        "Content-Type": "application/json",
    }

    data = {
        "prompt": {"text": text},
        # High-temperature / high top_k / top_p sampling.
        "temperature": 0.95,
        "top_k": 100,
        "top_p": 0.95,
        "candidate_count": candidates,
        "max_output_tokens": 1024,
        "stop_sequences": ["</output>"],
        # Every harm-category threshold set to 4.
        # NOTE(review): threshold 4 appears intended as "most permissive" —
        # confirm against the PaLM safety-settings documentation.
        "safety_settings": [
            {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 4},
            {"category": "HARM_CATEGORY_TOXICITY", "threshold": 4},
            {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 4},
            {"category": "HARM_CATEGORY_SEXUAL", "threshold": 4},
            {"category": "HARM_CATEGORY_MEDICAL", "threshold": 4},
            {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 4},
        ],
    }

    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data, headers=headers) as response:
            if response.status != 200:
                raise RuntimeError(
                    f"PaLM API error {response.status}: {await response.text()}"
                )
            result = await response.json()
            if candidates > 1:
                return [candidate["output"] for candidate in result["candidates"]]
            return result["candidates"][0]["output"]
76
+
77
+
78
async def Summarizer(essay):
    """Summarize *essay* with a map-reduce strategy over the PaLM API.

    The text is chunked by the module-level ``text_splitter``. A single
    chunk is summarized directly with the combine prompt; multiple chunks
    are first summarized concurrently with the map prompt, and the partial
    summaries are then merged into one final summary.
    """
    docs = text_splitter.create_documents([essay])

    # Single chunk: summarize it directly with the combine prompt.
    if len(docs) == 1:
        chunk_summaries = await asyncio.gather(
            *(PalmTextModel(combine_prompt.format(text=d.page_content)) for d in docs)
        )
        return " ".join(chunk_summaries)

    # Map step: summarize every chunk concurrently.
    partials = await asyncio.gather(
        *(PalmTextModel(map_prompt.format(text=d.page_content)) for d in docs)
    )

    # Reduce step: merge the partial summaries into the final answer.
    merged = " ".join(partials)
    return await PalmTextModel(combine_prompt.format(text=merged), candidates=1)
96
+
97
+