# GPT4News / app.py
import argparse
import json
import re
from functools import partial

import gradio as gr
import openai
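

# Each prompt entry in the JSON file is expected to carry the fields this
# module reads: name, pre_filter, post_filter, split_length, split_round,
# system, system_keys, user and user_keys. An illustrative entry (the values
# below are made up, not taken from prompts/interview.json):
#
#   {
#       "name": "summarize",
#       "pre_filter": true,
#       "post_filter": false,
#       "split_length": 2000,
#       "split_round": 10,
#       "system": "You are a news editor.",
#       "system_keys": [],
#       "user": "Summarize the interview:\n\n{input_txt}\n\nTags: {tags}",
#       "user_keys": ["input_txt", "tags"]
#   }
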
class GPT4News():
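    """Applies prompt templates loaded from a JSON file to input text through
    the OpenAI chat API, with optional transcript splitting and filtering."""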
def __init__(self, prompt_formats):
self.name2prompt = {x['name']: x for x in prompt_formats}
def preprocess(self, function_name, input_txt):
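        """Split the transcript into chunks of at most split_length characters
        and split_round speaker turns; if the prompt's pre_filter flag is off,
        return the raw text as a single chunk."""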
if not self.name2prompt[function_name]['pre_filter']:
return [input_txt]
max_length = self.name2prompt[function_name]['split_length']
max_convs = self.name2prompt[function_name]['split_round']
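        # Normalize Chinese speaker tags ('说话人<N> <hh:mm>', i.e. 'Speaker')
        # to the English form so a single pattern matches the whole transcript.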
input_txt = re.sub(r'(说话人)(\d+ \d\d:\d\d)', r'Speaker \2', input_txt)
speaker_pattern = re.compile(r'(Speaker \d+ \d\d:\d\d)')
input_txt = speaker_pattern.split(input_txt)
input_txt = [x.strip().replace('\n', ' ') for x in input_txt]
conversations = []
        for idx, txt in enumerate(input_txt):
            if not speaker_pattern.match(txt):
                continue
            # Pair each speaker tag with the utterance that follows it, if any.
            if (idx < len(input_txt) - 1
                    and not speaker_pattern.match(input_txt[idx + 1])):
                conv = [txt, input_txt[idx + 1]]
            else:
                conv = [txt, '']
            # Break over-long turns into max_length-sized pieces, repeating the
            # speaker tag; the extra guards keep the loop from stalling when
            # the tag alone already exceeds max_length.
            while len(''.join(conv)) > max_length and conv[1]:
                pruned_len = max(1, max_length - len(conv[0]))
                conversations.append([txt, conv[1][:pruned_len]])
                conv = [txt, conv[1][pruned_len:]]
            conversations.append(conv)
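        # Pack consecutive turns into request-sized chunks, starting a new one
        # whenever the character budget (max_length) or the turn budget
        # (max_convs) would be exceeded.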
input_txt_list = ['']
for conv in conversations:
conv_length = len(''.join(conv))
if len(input_txt_list[-1]) + conv_length >= max_length:
input_txt_list.append('')
elif len(speaker_pattern.findall(input_txt_list[-1])) >= max_convs:
input_txt_list.append('')
input_txt_list[-1] += ''.join(conv)
processed_txt_list = []
for input_txt in input_txt_list:
input_txt = ''.join(input_txt)
input_txt = speaker_pattern.sub(r'\n\1: ', input_txt)
processed_txt_list.append(input_txt.strip())
return processed_txt_list
    def chatgpt(self, messages, temperature=0.0, max_retries=5):
        # Retry failed API calls a bounded number of times and give up with an
        # empty string, so a persistent error (e.g. an invalid API key) cannot
        # hang the app.
        for attempt in range(max_retries):
            try:
                completion = openai.ChatCompletion.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    temperature=temperature
                )
                return completion.choices[0].message.content
            except Exception as err:
                print(f'OpenAI API error ({attempt + 1}/{max_retries}): {err}')
        return ''
def llm(self, function_name, temperature, **kwargs):
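        """Fill the selected prompt's system and user templates from kwargs and
        send them to the chat model."""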
prompt = self.name2prompt[function_name]
user_kwargs = {key: kwargs[key] for key in prompt['user_keys']}
user = prompt['user'].format(**user_kwargs)
system_kwargs = {key: kwargs[key] for key in prompt['system_keys']}
system = prompt['system'].format(**system_kwargs)
messages = [
{'role': 'system',
'content': system},
{'role': 'user',
'content': user}]
response = self.chatgpt(messages, temperature=temperature)
print(f'SYSTEM:\n\n{system}')
print(f'USER:\n\n{user}')
print(f'RESPONSE:\n\n{response}')
return response
def translate(self, txt, output_lang):
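        """Translate txt into output_lang with the chat model; returns txt
        unchanged when the target language is English."""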
if output_lang == 'English':
return txt
system = 'You are a translator.'
user = 'Translate the following text to {}:\n\n{}'.format(
output_lang, txt)
messages = [{'role': 'system', 'content': system},
{'role': 'user', 'content': user}]
response = self.chatgpt(messages)
print(f'SYSTEM:\n\n{system}')
print(f'USER:\n\n{user}')
print(f'RESPONSE:\n\n{response}')
return response
def postprocess(self, function_name, input_txt, output_txt_list,
output_lang):
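        """Join the per-chunk outputs and translate them into output_lang; when
        the prompt's post_filter flag is set, first drop turns whose speaker
        tags do not appear in the input."""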
if not self.name2prompt[function_name]['post_filter']:
output_txt = '\n\n'.join(output_txt_list)
output_txt = self.translate(output_txt, output_lang)
return output_txt
speaker_pattern = re.compile(r'(Speaker \d+ \d\d:\d\d)')
output_txt = []
for txt in output_txt_list:
if len(speaker_pattern.findall(txt)) > 0:
output_txt.append(txt)
output_txt = ''.join(output_txt)
speakers = set(speaker_pattern.findall(input_txt))
output_txt = speaker_pattern.split(output_txt)
results = []
        for idx, txt in enumerate(output_txt):
            if not speaker_pattern.match(txt):
                continue
            # Drop turns attributed to speaker tags that never appear in the
            # input, i.e. tags the model invented.
            if txt not in speakers:
                continue
            if (idx < len(output_txt) - 1
                    and not speaker_pattern.match(output_txt[idx + 1])):
                res = txt + output_txt[idx + 1]
            else:
                res = txt
            res = self.translate(res, output_lang)
            results.append(res.strip())
return '\n\n'.join(results)
def __call__(self, api_key, function_name, temperature, output_lang,
input_txt, tags):
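        """Gradio entry point: validate the inputs, preprocess the text into
        chunks, query the model once per chunk, and postprocess the results
        into a single output string."""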
        if not api_key:
            return 'OpenAI API key is not set.'
        if not function_name:
            return 'No function is selected.'
openai.api_key = api_key
input_txt_list = self.preprocess(function_name, input_txt)
input_txt = '\n'.join(input_txt_list)
output_txt_list = []
for txt in input_txt_list:
llm_kwargs = dict(input_txt=txt,
tags=tags)
output_txt = self.llm(function_name, temperature, **llm_kwargs)
output_txt_list.append(output_txt)
output_txt = self.postprocess(
function_name, input_txt, output_txt_list, output_lang)
return output_txt
@property
def function_names(self):
return self.name2prompt.keys()
def function_name_select_callback(components, name2prompt, function_name):
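    """Show only the input widgets named in the selected function's user_keys."""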
prompt = name2prompt[function_name]
user_keys = prompt['user_keys']
result = []
    for comp in components:
result.append(gr.update(visible=comp in user_keys))
return result
if __name__ == '__main__':
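    # Example invocation (using the defaults declared below):
    #   python app.py --prompt prompts/interview.json --temperature 0.7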
parser = argparse.ArgumentParser()
parser.add_argument('--prompt', type=str, default='prompts/interview.json',
help='path to the prompt file')
    parser.add_argument('--temperature', type=float, default=0.7,
help='temperature for the llm model')
args = parser.parse_args()
    with open(args.prompt, 'r', encoding='utf-8') as f:
        prompt_formats = json.load(f)
gpt4news = GPT4News(prompt_formats)
languages = ['Arabic', 'Bengali', 'Chinese (Simplified)',
'Chinese (Traditional)', 'Dutch', 'English', 'French',
'German', 'Hindi', 'Italian', 'Japanese', 'Korean',
'Portuguese', 'Punjabi', 'Russian', 'Spanish', 'Turkish',
'Urdu']
default_func = sorted(gpt4news.function_names)[0]
default_user_keys = gpt4news.name2prompt[default_func]['user_keys']
with gr.Blocks() as demo:
with gr.Row():
with gr.Column(scale=0.3):
with gr.Row():
                    api_key = gr.Textbox(
                        lines=1,
                        label='OpenAI API Key',
                        elem_id='api_key_textbox',
                        placeholder='Enter your OpenAI API key')
with gr.Row():
function_name = gr.Dropdown(
sorted(gpt4news.function_names),
value=default_func,
elem_id='function_dropdown',
label='Function',
info='choose a function to run')
with gr.Row():
output_lang = gr.Dropdown(
languages,
value='English',
elem_id='output_lang_dropdown',
label='Output Language',
                        info='choose the output language')
with gr.Row():
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=args.temperature,
step=0.1,
interactive=True,
label='Temperature',
                        info='higher values give more creative output')
with gr.Row():
tags = gr.Textbox(
lines=1,
visible='tags' in default_user_keys,
label='Tags',
elem_id='tags_textbox',
                        placeholder='Enter tags separated by semicolons')
with gr.Row():
input_txt = gr.Textbox(
lines=4,
visible='input_txt' in default_user_keys,
label='Input',
elem_id='input_textbox',
placeholder='Enter text and press submit')
with gr.Row():
submit = gr.Button('Submit')
with gr.Row():
clear = gr.Button('Clear')
with gr.Column(scale=0.7):
output_txt = gr.Textbox(
lines=8,
label='Output',
elem_id='output_textbox')
function_name.select(
partial(function_name_select_callback, ['input_txt', 'tags'],
gpt4news.name2prompt),
[function_name],
[input_txt, tags]
)
submit.click(
gpt4news,
[api_key, function_name, temperature, output_lang,
input_txt, tags],
[output_txt])
        clear.click(
            lambda: ('', '', ''),
            None,
            [tags, input_txt, output_txt])
demo.queue(concurrency_count=6)
demo.launch()