Spaces:
Sleeping
Sleeping
File size: 4,760 Bytes
928f123 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import ast
import copy
import toml
from string import Template
from pathlib import Path
from flatdict import FlatDict
import google.generativeai as genai
from gen.utils import parse_first_json_snippet
def determine_model_name(given_image=None):
    """Pick the Gemini model identifier for a request.

    Returns "gemini-pro-vision" when an image is supplied and "gemini-pro"
    when ``given_image`` is None.
    """
    has_image = given_image is not None
    return "gemini-pro-vision" if has_image else "gemini-pro"
def construct_image_part(given_image):
    """Wrap image data in the dict form that is appended to the prompt parts.

    The MIME type is always reported as "image/jpeg"; ``given_image`` is
    stored unchanged under the "data" key.
    """
    image_part = dict(mime_type="image/jpeg", data=given_image)
    return image_part
def call_gemini(prompt="", API_KEY=None, given_text=None, given_image=None, generation_config=None, safety_settings=None):
    """Send a prompt, optionally with context text and an image, to Gemini.

    Args:
        prompt: Instruction text placed at the start of the request.
        API_KEY: Key handed to ``genai.configure``.
        given_text: Optional context; when given it is placed below the
            prompt, separated by a dashed line.
        given_image: Optional image data; when given, the vision model is
            selected and the image is attached as an extra prompt part.
        generation_config: Optional generation settings. Defaults to
            temperature 0.8, top_p 1, top_k 32, max_output_tokens 4096.
        safety_settings: Optional safety settings. Defaults to "BLOCK_NONE"
            for the harassment, hate-speech, sexually-explicit and
            dangerous-content categories.

    Returns:
        The ``text`` attribute of the response from ``generate_content``.
    """
    genai.configure(api_key=API_KEY)

    if generation_config is None:
        generation_config = {
            "temperature": 0.8,
            "top_p": 1,
            "top_k": 32,
            "max_output_tokens": 4096,
        }

    if safety_settings is None:
        safety_settings = [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_NONE"
            },
        ]

    model_name = determine_model_name(given_image)
    model = genai.GenerativeModel(model_name=model_name,
                                  generation_config=generation_config,
                                  safety_settings=safety_settings)

    user_prompt = prompt
    if given_text is not None:
        # Build the combined prompt once. Appending this string to a value
        # that already holds the prompt would send the prompt twice.
        user_prompt = (
            f"{prompt}\n"
            "------------------------------------------------\n"
            f"{given_text}\n"
        )

    prompt_parts = [user_prompt]
    if given_image is not None:
        prompt_parts.append(construct_image_part(given_image))

    response = model.generate_content(prompt_parts)
    return response.text
def try_out(prompt, given_text, gemini_api_key, given_image=None, retry_num=5):
    """Call Gemini and parse the first JSON snippet from its reply, with retries.

    Every attempt counts against ``retry_num``, whether it raised an
    exception or completed with a parse result of None. This keeps the
    number of API calls bounded even when no exception is raised.

    Args:
        prompt: Instruction text forwarded to ``call_gemini``.
        given_text: Context text forwarded to ``call_gemini``.
        gemini_api_key: API key forwarded to ``call_gemini``.
        given_image: Optional image data forwarded to ``call_gemini``.
        retry_num: Maximum number of attempts.

    Returns:
        The value returned by ``parse_first_json_snippet`` for the first
        attempt that yields a non-None result, or None when all attempts
        are used up.
    """
    attempts = 0
    while attempts < retry_num:
        attempts += 1
        try:
            qna = call_gemini(
                prompt=prompt,
                given_text=given_text,
                given_image=given_image,
                API_KEY=gemini_api_key
            )
            qna_json = parse_first_json_snippet(qna)
        except Exception as e:
            # Deliberately broad: any API or parsing failure is treated as
            # a failed attempt and retried.
            print(f"......retry {e}")
            continue
        if qna_json is not None:
            return qna_json
    return None
def get_basic_qa(text, gemini_api_key, trucate=7000):
    """Request the basic Q&A for ``text`` using the 'basic_qa' prompt.

    The prompt is read from ``constants/prompts.toml`` relative to the
    current working directory, and only the first ``trucate`` characters of
    ``text`` are sent. Returns whatever ``try_out`` returns.
    """
    prompt_file = Path('.') / 'constants' / 'prompts.toml'
    basic_prompt = toml.load(prompt_file)['basic_qa']['prompt']
    return try_out(basic_prompt, text[:trucate], gemini_api_key=gemini_api_key)
def _is_complete_follow_up(response):
    """Return True when a follow-up response has the question and both answers."""
    try:
        return (
            'follow up question' in response
            and 'answers' in response
            and 'eli5' in response['answers']
            and 'expert' in response['answers']
        )
    except TypeError:
        # try_out returns None when it runs out of retries, and a malformed
        # reply may not be a mapping; neither is a usable follow-up.
        return False


def _request_follow_up(prompt, context, gemini_api_key, max_attempts):
    """Ask for a follow-up Q&A until it is complete or attempts run out."""
    for _ in range(max_attempts):
        response = try_out(prompt, context, gemini_api_key=gemini_api_key)
        if _is_complete_follow_up(response):
            return response
    return None


def get_deep_qa(text, basic_qa, gemini_api_key, trucate=7000, max_attempts=5):
    """Extend each basic Q&A item with in-depth and broad follow-up questions.

    For every item in ``basic_qa['qna']`` two follow-ups are requested using
    the 'deep_qa' prompt from ``constants/prompts.toml``: one with tone
    "in-depth" and one with tone "broad". Each item is then flattened, its
    keys are prefixed with the item index, and the result is merged into
    ``basic_qa``.

    Args:
        text: Source text; only the first ``trucate`` characters are sent.
        basic_qa: Dict with a 'title' and a 'qna' list whose items carry a
            'question' and an 'answers' mapping with an 'expert' entry.
            It is updated in place.
        gemini_api_key: API key forwarded to ``try_out``.
        trucate: Maximum number of characters of ``text`` to send.
        max_attempts: How many times each follow-up is requested before
            giving up. When no complete response is obtained, the
            corresponding 'additional_*' entry is omitted for that item.

    Returns:
        The same ``basic_qa`` object, after the update.
    """
    prompts = toml.load(Path('.') / 'constants' / 'prompts.toml')
    # The template and the truncated context do not change between items.
    deep_prompt = Template(prompts['deep_qa']['prompt'])
    context = text[:trucate]

    title = basic_qa['title']
    qnas = copy.deepcopy(basic_qa['qna'])

    follow_ups = (
        ('additional_depth_q', 'in-depth'),
        ('additional_breath_q', 'broad'),
    )

    for idx, qna in enumerate(qnas):
        q = qna['question']
        a_expert = qna['answers']['expert']

        for key, tone in follow_ups:
            follow_up_prompt = deep_prompt.substitute(
                title=title, previous_question=q, previous_answer=a_expert, tone=tone
            )
            response = _request_follow_up(
                follow_up_prompt, context, gemini_api_key, max_attempts
            )
            if response is not None:
                qna[key] = response

        # Flatten the nested item and prefix every flattened key with the
        # item index so entries from different items cannot collide.
        qna = FlatDict(qna)
        qna_tmp = copy.deepcopy(qna)
        for k in qna_tmp:
            value = qna.pop(k)
            qna[f'{idx}_{k}'] = value

        # NOTE(review): this relies on str(FlatDict) being a dict literal
        # that ast.literal_eval can parse back into a plain dict - confirm
        # against the installed flatdict version.
        basic_qa.update(ast.literal_eval(str(qna)))

    return basic_qa