File size: 4,760 Bytes
928f123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import ast
import copy
import toml
from string import Template
from pathlib import Path
from flatdict import FlatDict
import google.generativeai as genai

from gen.utils import parse_first_json_snippet

def determine_model_name(given_image=None):
    """Return the Gemini model id to use for a request.

    The text-only model is chosen when no image is supplied (``given_image``
    is None); any other value, even an empty one, selects the vision model.
    """
    return "gemini-pro" if given_image is None else "gemini-pro-vision"

def construct_image_part(given_image, mime_type="image/jpeg"):
    """Wrap encoded image data as an inline-data prompt part.

    Args:
        given_image: Encoded image data, inserted unchanged under "data".
            Assumed to already be encoded as *mime_type*; this function
            does not check that.
        mime_type: MIME type reported for *given_image*. Defaults to
            "image/jpeg", the value that used to be hard-coded, so existing
            callers are unaffected; pass e.g. "image/png" for other formats.

    Returns:
        A dict with "mime_type" and "data" keys.
    """
    return {
        "mime_type": mime_type,
        "data": given_image,
    }

def call_gemini(prompt="", API_KEY=None, given_text=None, given_image=None, generation_config=None, safety_settings=None):
    """Send a single prompt to Gemini and return the reply text.

    Args:
        prompt: Instruction text; always sent first.
        API_KEY: Passed to ``genai.configure`` before the request.
        given_text: Optional source text. When given, it is appended after
            *prompt*, separated by a dashed rule line.
        given_image: Optional image data. When given, the vision model is
            selected (see ``determine_model_name``) and the image is attached
            via ``construct_image_part``.
        generation_config: Sampling settings; defaults to temperature 0.8,
            top_p 1, top_k 32, max_output_tokens 4096.
        safety_settings: Safety settings; the default sets the harassment,
            hate-speech, sexually-explicit and dangerous-content categories
            to BLOCK_NONE.

    Returns:
        ``response.text`` of the model reply.

    Raises:
        Any exception from the ``genai`` client propagates unchanged
        (``try_out`` catches these and retries).
    """
    genai.configure(api_key=API_KEY)

    if generation_config is None:
        generation_config = {
            "temperature": 0.8,
            "top_p": 1,
            "top_k": 32,
            "max_output_tokens": 4096,
        }

    if safety_settings is None:
        safety_settings = [
            {
                "category": "HARM_CATEGORY_HARASSMENT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_HATE_SPEECH",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "threshold": "BLOCK_NONE"
            },
            {
                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                "threshold": "BLOCK_NONE"
            },
        ]

    model_name = determine_model_name(given_image)
    model = genai.GenerativeModel(model_name=model_name,
                                  generation_config=generation_config,
                                  safety_settings=safety_settings)

    user_prompt = prompt
    if given_text is not None:
        # Build prompt + rule + text once. Assigning (not ``+=``) avoids
        # sending the prompt twice, and the plain "\n" joins avoid leaking
        # source indentation into the request.
        user_prompt = (
            f"{prompt}\n"
            "------------------------------------------------\n"
            f"{given_text}\n"
        )
    prompt_parts = [user_prompt]
    if given_image is not None:
        prompt_parts.append(construct_image_part(given_image))

    response = model.generate_content(prompt_parts)
    return response.text

def try_out(prompt, given_text, gemini_api_key, given_image=None, retry_num=5):
    """Call Gemini and parse the first JSON snippet of the reply, with retries.

    At most *retry_num* requests are made. Each attempt counts toward the
    limit, whether it raised or produced no parsed value. Previously only
    raising attempts counted, so a parse result of None caused unbounded
    API calls.

    Args:
        prompt: Instruction text forwarded to ``call_gemini``.
        given_text: Source text forwarded to ``call_gemini``.
        gemini_api_key: API key forwarded to ``call_gemini``.
        given_image: Optional image forwarded to ``call_gemini``.
        retry_num: Maximum number of attempts; values <= 0 make none.

    Returns:
        The first non-None result of ``parse_first_json_snippet``, or None
        if every attempt failed.
    """
    for _ in range(retry_num):
        try:
            reply = call_gemini(
                prompt=prompt,
                given_text=given_text,
                given_image=given_image,
                API_KEY=gemini_api_key
            )
            parsed = parse_first_json_snippet(reply)
        except Exception as e:
            # Deliberately broad: API, network, blocked-response and parse
            # errors are all treated as "try again".
            print(f"......retry {e}")
            continue

        if parsed is not None:
            return parsed
        # Guard in case parse_first_json_snippet signals "no JSON" by
        # returning None rather than raising — TODO confirm its contract.
        print("......retry (no JSON parsed from response)")

    return None

def get_basic_qa(text, gemini_api_key, trucate=7000):
    """Request the initial Q&A set for *text* from Gemini.

    Loads ``constants/prompts.toml`` (relative to the current working
    directory), takes the ``basic_qa.prompt`` entry, and sends it via
    ``try_out`` together with the first *trucate* characters of *text*.

    Returns:
        Whatever ``try_out`` returns: the parsed JSON, or None on failure.
    """
    prompt_file = Path('.') / 'constants' / 'prompts.toml'
    basic_prompt = toml.load(prompt_file)['basic_qa']['prompt']
    excerpt = text[:trucate]
    return try_out(basic_prompt, excerpt, gemini_api_key=gemini_api_key)


def _is_complete_followup(response):
    """Return True if *response* holds every field get_deep_qa expects.

    A complete follow-up is a dict with a 'follow up question' key and an
    'answers' dict containing both 'eli5' and 'expert'. Anything else,
    including None from an exhausted ``try_out``, is rejected rather than
    raising.
    """
    if not isinstance(response, dict):
        return False
    answers = response.get('answers')
    return (
        'follow up question' in response
        and isinstance(answers, dict)
        and 'eli5' in answers
        and 'expert' in answers
    )


def _request_followup(prompt, context, gemini_api_key, max_attempts=None):
    """Call ``try_out`` until it yields a complete follow-up response.

    Args:
        prompt: Filled-in deep_qa prompt.
        context: (Already truncated) source text sent with the prompt.
        gemini_api_key: API key forwarded to ``try_out``.
        max_attempts: Maximum number of ``try_out`` calls; None keeps
            asking indefinitely.

    Returns:
        The first response accepted by ``_is_complete_followup``, or None
        once *max_attempts* calls have been made without one.
    """
    attempts = 0
    while max_attempts is None or attempts < max_attempts:
        attempts += 1
        response = try_out(prompt, context, gemini_api_key=gemini_api_key)
        if _is_complete_followup(response):
            return response
    return None


def get_deep_qa(text, basic_qa, gemini_api_key, trucate=7000, max_attempts=None):
    """Add an in-depth and a broad follow-up to every basic Q&A entry.

    For each item of ``basic_qa['qna']``, the ``deep_qa.prompt`` template from
    ``constants/prompts.toml`` (relative to the current working directory) is
    filled twice, once with tone="in-depth" and once with tone="broad". Each
    is sent with the first *trucate* characters of *text*. The follow-ups are
    attached to a deep copy of the item under 'additional_depth_q' and
    'additional_breath_q'. The item is then flattened with FlatDict (nested
    keys joined with its default delimiter), each key is prefixed with
    "<idx>_", and the result is merged into *basic_qa*.

    Args:
        text: Source document; only ``text[:trucate]`` is sent.
        basic_qa: Output of get_basic_qa. It must contain 'title' and 'qna';
            each qna item needs 'question' and 'answers' -> 'expert'.
            Updated in place.
        gemini_api_key: API key forwarded to the Gemini calls.
        trucate: Character limit applied to *text* (parameter name kept as-is).
        max_attempts: Per-follow-up cap on ``try_out`` calls. None (default)
            keeps retrying until a complete response arrives, as before. When
            the cap is reached, that follow-up is left out of the entry.

    Returns:
        The same *basic_qa* object, now holding the flattened, index-prefixed
        keys.
    """
    prompts = toml.load(Path('.') / 'constants' / 'prompts.toml')
    deep_qa_template = Template(prompts['deep_qa']['prompt'])
    context = text[:trucate]

    title = basic_qa['title']
    qnas = copy.deepcopy(basic_qa['qna'])

    for idx, qna in enumerate(qnas):
        question = qna['question']
        expert_answer = qna['answers']['expert']

        depth_prompt = deep_qa_template.substitute(
            title=title, previous_question=question, previous_answer=expert_answer, tone="in-depth"
        )
        breadth_prompt = deep_qa_template.substitute(
            title=title, previous_question=question, previous_answer=expert_answer, tone="broad"
        )

        depth_response = _request_followup(depth_prompt, context, gemini_api_key, max_attempts)
        breadth_response = _request_followup(breadth_prompt, context, gemini_api_key, max_attempts)

        if depth_response is not None:
            qna['additional_depth_q'] = depth_response
        if breadth_response is not None:
            # Output key spelling ("breath") is kept unchanged for compatibility.
            qna['additional_breath_q'] = breadth_response

        # Flatten and prefix directly. The old str()/ast.literal_eval round
        # trip broke on values whose repr is not a Python literal, such as NaN
        # parsed from JSON or an empty nested FlatDict.
        flat = FlatDict(qna)
        basic_qa.update({
            f'{idx}_{key}': dict(value) if isinstance(value, FlatDict) else value
            for key, value in flat.items()
        })

    return basic_qa