# paper_qa/gen/gemini.py
# Last commit: "Update gen/gemini.py" (6e0c0c0, verified) -- chansung
# File size: 4.77 kB
import ast
import copy
import toml
from string import Template
from pathlib import Path
from flatdict import FlatDict
import google.generativeai as genai
from gen.utils import parse_first_json_snippet
def determine_model_name(given_image=None):
    """Pick the Gemini model identifier for a request.

    Args:
        given_image: Image payload for the request, or ``None`` when the
            request is text-only.

    Returns:
        ``"gemini-1.0-pro-vision"`` when an image is supplied (anything other
        than ``None``), otherwise ``"gemini-1.0-pro"``.
    """
    has_image = given_image is not None
    return "gemini-1.0-pro-vision" if has_image else "gemini-1.0-pro"
def construct_image_part(given_image):
    """Wrap image data in the dict form used as a prompt part.

    Args:
        given_image: The image payload; stored unchanged under ``"data"``.

    Returns:
        A dict with ``"mime_type"`` fixed to ``"image/jpeg"`` and ``"data"``
        set to ``given_image``.
    """
    image_part = dict(mime_type="image/jpeg", data=given_image)
    return image_part
def call_gemini(prompt="", API_KEY=None, given_text=None, given_image=None, generation_config=None, safety_settings=None):
    """Send a prompt to Gemini and return the text of the response.

    Args:
        prompt: Instruction text placed at the start of the request.
        API_KEY: Key passed to ``genai.configure``.
        given_text: Optional text appended after the prompt, below a
            separator line. When ``None`` only ``prompt`` is sent.
        given_image: Optional image payload. When given, the vision model is
            selected and the image is appended as an extra prompt part.
        generation_config: Optional generation settings; a default
            (temperature 0.8, top_p 1, top_k 32, 4096 output tokens) is used
            when ``None``.
        safety_settings: Optional safety settings; when ``None`` the four
            harm categories listed below are set to ``BLOCK_NONE``.

    Returns:
        The ``text`` attribute of the ``generate_content`` response.
    """
    genai.configure(api_key=API_KEY)

    if generation_config is None:
        generation_config = {
            "temperature": 0.8,
            "top_p": 1,
            "top_k": 32,
            "max_output_tokens": 4096,
        }

    if safety_settings is None:
        safety_settings = [
            {"category": category, "threshold": "BLOCK_NONE"}
            for category in (
                "HARM_CATEGORY_HARASSMENT",
                "HARM_CATEGORY_HATE_SPEECH",
                "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "HARM_CATEGORY_DANGEROUS_CONTENT",
            )
        ]

    model = genai.GenerativeModel(
        model_name=determine_model_name(given_image),
        generation_config=generation_config,
        safety_settings=safety_settings,
    )

    user_prompt = prompt
    if given_text is not None:
        # Build the combined prompt exactly once: prompt, separator line,
        # then the text. Appending this to a string that already holds the
        # prompt would send the instructions to the model twice.
        separator = "------------------------------------------------"
        user_prompt = f"{prompt}\n{separator}\n{given_text}\n"

    prompt_parts = [user_prompt]
    if given_image is not None:
        prompt_parts.append(construct_image_part(given_image))

    response = model.generate_content(prompt_parts)
    return response.text
def try_out(prompt, given_text, gemini_api_key, given_image=None, retry_num=10):
    """Call Gemini and parse the first JSON snippet of its reply, with retries.

    Every attempt counts against ``retry_num``, whether it failed by raising
    or by producing ``None``, so the number of API calls is always bounded.

    Args:
        prompt: Instruction text forwarded to ``call_gemini``.
        given_text: Text forwarded to ``call_gemini``.
        gemini_api_key: API key forwarded to ``call_gemini``.
        given_image: Optional image payload forwarded to ``call_gemini``.
        retry_num: Maximum number of attempts.

    Returns:
        The value returned by ``parse_first_json_snippet`` for the first
        attempt that yields something other than ``None``, or ``None`` when
        no attempt succeeds (including when ``retry_num`` is not positive).
    """
    for _ in range(retry_num):
        try:
            raw_reply = call_gemini(
                prompt=prompt,
                given_text=given_text,
                given_image=given_image,
                API_KEY=gemini_api_key
            )
            parsed = parse_first_json_snippet(raw_reply)
        except Exception as e:
            # Deliberately best-effort: any failure of the call or the
            # parsing is reported and the next attempt is made.
            print(f"......retry {e}")
            continue

        if parsed is not None:
            return parsed

    return None
def get_basic_qa(text, gemini_api_key, trucate=7000):
    """Ask Gemini for the basic Q&A of ``text``.

    Reads the ``basic_qa`` prompt from ``constants/prompts.toml`` (relative to
    the current working directory) and sends it with the first ``trucate``
    characters of ``text``.

    Args:
        text: Source text to ask about.
        gemini_api_key: API key forwarded to ``try_out``.
        trucate: Number of leading characters of ``text`` to send.

    Returns:
        Whatever ``try_out`` returns for that prompt and text.
    """
    prompt_file = Path('.') / 'constants' / 'prompts.toml'
    basic_prompt = toml.load(prompt_file)['basic_qa']['prompt']
    clipped_text = text[:trucate]
    return try_out(basic_prompt, clipped_text, gemini_api_key=gemini_api_key)
def _has_follow_up_fields(response):
    """Return True when ``response`` holds a follow-up question and both answer tones."""
    return (
        'follow up question' in response
        and 'answers' in response
        and 'eli5' in response['answers']
        and 'expert' in response['answers']
    )


def _request_follow_up(search_prompt, text, gemini_api_key):
    """Call ``try_out`` until it gives a complete follow-up, or gives up with ``None``."""
    while True:
        response = try_out(search_prompt, text, gemini_api_key=gemini_api_key)
        # ``try_out`` returns None once its own retries are exhausted; stop
        # there instead of running membership tests against None.
        if response is None or _has_follow_up_fields(response):
            return response


def get_deep_qa(text, basic_qa, gemini_api_key, trucate=7000):
    """Extend ``basic_qa`` with in-depth and broad follow-up Q&As.

    For every entry of ``basic_qa['qna']`` two follow-ups are requested from
    Gemini, using the ``deep_qa`` prompt from ``constants/prompts.toml`` with
    the tones "in-depth" and "broad". Each entry is then flattened, its keys
    are prefixed with the entry index (``f'{idx}_{key}'``), and the result is
    merged into ``basic_qa``.

    Args:
        text: Source text; only the first ``trucate`` characters are sent.
        basic_qa: Mapping with a ``'title'`` and a ``'qna'`` list whose items
            provide ``'question'`` and ``'answers'['expert']``. It is updated
            in place.
        gemini_api_key: API key forwarded to ``try_out``.
        trucate: Number of leading characters of ``text`` to send.

    Returns:
        The same ``basic_qa`` object, after the update.
    """
    prompts = toml.load(Path('.') / 'constants' / 'prompts.toml')
    title = basic_qa['title']
    # Work on a copy so the nested 'qna' entries of basic_qa stay untouched.
    qnas = copy.deepcopy(basic_qa['qna'])

    for idx, qna in enumerate(qnas):
        q = qna['question']
        a_expert = qna['answers']['expert']

        deep_template = Template(prompts['deep_qa']['prompt'])
        depth_search_prompt = deep_template.substitute(
            title=title, previous_question=q, previous_answer=a_expert, tone="in-depth"
        )
        breath_search_prompt = deep_template.substitute(
            title=title, previous_question=q, previous_answer=a_expert, tone="broad"
        )

        depth_search_response = _request_follow_up(
            depth_search_prompt, text[:trucate], gemini_api_key
        )
        breath_search_response = _request_follow_up(
            breath_search_prompt, text[:trucate], gemini_api_key
        )

        # A follow-up is only attached when Gemini actually produced one.
        if depth_search_response is not None:
            qna['additional_depth_q'] = depth_search_response
        if breath_search_response is not None:
            qna['additional_breath_q'] = breath_search_response

        flat_qna = FlatDict(qna)
        # Iterate over a snapshot because the keys of flat_qna are replaced
        # by their index-prefixed versions inside the loop.
        snapshot = copy.deepcopy(flat_qna)
        for k in snapshot:
            value = flat_qna.pop(k)
            flat_qna[f'{idx}_{k}'] = value

        # NOTE(review): the FlatDict is turned into a plain dict through its
        # string form; presumably this relies on str(FlatDict) being a valid
        # dict literal of plain values -- confirm against flatdict.
        basic_qa.update(ast.literal_eval(str(flat_qna)))

    return basic_qa