# Hugging Face Spaces status: Runtime error
import argparse
import json
import re
import time
from functools import partial

import gradio as gr
import openai
class GPT4News():
    """Apply prompt-configured ChatGPT functions to (speaker-tagged) text.

    ``prompt_formats`` is a list of dicts, one per function, keyed by their
    ``'name'``.  The other keys this class reads are ``pre_filter``,
    ``split_length`` and ``split_round`` (input chunking), ``post_filter``
    (output filtering), and ``system``/``system_keys``/``user``/``user_keys``
    (message templates and the keyword arguments they are formatted with).
    """

    # A speaker header such as "Speaker 1 00:05".  The capturing group makes
    # re.split() keep the headers, so split results alternate
    # [text, header, text, header, ..., text].
    SPEAKER_PATTERN = re.compile(r'(Speaker \d+ \d\d:\d\d)')

    def __init__(self, prompt_formats):
        self.name2prompt = {x['name']: x for x in prompt_formats}

    def preprocess(self, function_name, input_txt):
        """Split a speaker-tagged transcript into LLM-sized chunks.

        When the prompt's ``pre_filter`` is falsy the text is returned
        unchanged as a single chunk.  Otherwise:

        * ``说话人N hh:mm`` headers are rewritten to ``Speaker N hh:mm``;
        * each speaker turn is flattened onto one line (text before the first
          header is discarded);
        * a turn longer than ``split_length`` characters is cut into pieces
          that each repeat the header;
        * turns are packed into chunks whose raw length stays below
          ``split_length`` (a single oversized piece still forms its own
          chunk) and that hold at most ``split_round`` turn pieces.

        Returns:
            list[str]: chunks where every turn is rendered as
            ``"Speaker N hh:mm: text"`` on its own line.
        """
        prompt = self.name2prompt[function_name]
        if not prompt['pre_filter']:
            return [input_txt]
        max_length = prompt['split_length']
        max_convs = prompt['split_round']
        speaker_pattern = self.SPEAKER_PATTERN

        # Normalise Chinese speaker headers ("说话人1 00:05") to English ones.
        input_txt = re.sub(r'(说话人)(\d+ \d\d:\d\d)', r'Speaker \2', input_txt)
        pieces = [x.strip().replace('\n', ' ')
                  for x in speaker_pattern.split(input_txt)]

        conversations = []
        for idx, txt in enumerate(pieces):
            if not speaker_pattern.match(txt):
                continue
            # The element after a header is that speaker's text, unless it is
            # itself a header (an empty turn).
            if (idx + 1 < len(pieces)
                    and not speaker_pattern.match(pieces[idx + 1])):
                body = pieces[idx + 1]
            else:
                body = ''
            # Characters left for the body once the header is counted.  At
            # least one, so that a header at or beyond max_length cannot stall
            # the loop below.
            budget = max(max_length - len(txt), 1)
            while len(body) > budget:
                conversations.append([txt, body[:budget]])
                body = body[budget:]
            conversations.append([txt, body])

        chunks = ['']
        for header, body in conversations:
            turn = header + body
            current = chunks[-1]
            # Start a new chunk when this turn would reach the length limit or
            # the chunk already holds max_convs turns -- but never leave an
            # empty chunk behind (it would be sent to the LLM as-is).
            if current and (
                    len(current) + len(turn) >= max_length
                    or len(speaker_pattern.findall(current)) >= max_convs):
                chunks.append('')
            chunks[-1] += turn

        # Put each header on its own line, followed by ": ".
        return [speaker_pattern.sub(r'\n\1: ', chunk).strip()
                for chunk in chunks]

    def chatgpt(self, messages, temperature=0.0, api_key=None, max_retries=3,
                retry_delay=1.0):
        """Send ``messages`` to gpt-3.5-turbo and return the reply text.

        Args:
            messages: chat messages (``role``/``content`` dicts).
            temperature: sampling temperature.
            api_key: OpenAI key for this request; ``None`` falls back to the
                library-wide ``openai.api_key``.
            max_retries: how many times a failed request is retried.
            retry_delay: initial back-off in seconds, doubled on each retry.

        Raises:
            Exception: whatever the last attempt raised, once the retries
                are exhausted.
        """
        attempts = max(max_retries, 0) + 1
        for attempt in range(attempts):
            try:
                completion = openai.ChatCompletion.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    temperature=temperature,
                    api_key=api_key,
                )
                return completion.choices[0].message.content
            except Exception as err:
                print(err)
                if attempt == attempts - 1:
                    raise
                time.sleep(retry_delay * 2 ** attempt)

    def llm(self, function_name, temperature, api_key=None, **kwargs):
        """Fill the prompt templates of ``function_name`` and query ChatGPT.

        ``kwargs`` must supply every key listed in the prompt's ``user_keys``
        and ``system_keys``; they are substituted with ``str.format``.  The
        system prompt, user prompt and response are printed to stdout.
        """
        prompt = self.name2prompt[function_name]
        user_kwargs = {key: kwargs[key] for key in prompt['user_keys']}
        user = prompt['user'].format(**user_kwargs)
        system_kwargs = {key: kwargs[key] for key in prompt['system_keys']}
        system = prompt['system'].format(**system_kwargs)
        messages = [
            {'role': 'system',
             'content': system},
            {'role': 'user',
             'content': user}]
        response = self.chatgpt(messages, temperature=temperature,
                                api_key=api_key)
        print(f'SYSTEM:\n\n{system}')
        print(f'USER:\n\n{user}')
        print(f'RESPONSE:\n\n{response}')
        return response

    def translate(self, txt, output_lang, api_key=None):
        """Translate ``txt`` into ``output_lang`` with ChatGPT.

        English is returned unchanged without an API call.
        """
        if output_lang == 'English':
            return txt
        system = 'You are a translator.'
        user = 'Translate the following text to {}:\n\n{}'.format(
            output_lang, txt)
        messages = [{'role': 'system', 'content': system},
                    {'role': 'user', 'content': user}]
        response = self.chatgpt(messages, api_key=api_key)
        print(f'SYSTEM:\n\n{system}')
        print(f'USER:\n\n{user}')
        print(f'RESPONSE:\n\n{response}')
        return response

    def postprocess(self, function_name, input_txt, output_txt_list,
                    output_lang, api_key=None):
        """Merge the per-chunk LLM outputs into the final answer.

        Without ``post_filter`` the outputs are joined with blank lines and
        translated as one text.  With it, only outputs containing a speaker
        header are kept, and only sections whose header also occurs in
        ``input_txt`` survive.  Each section is translated separately and the
        sections are joined with blank lines.
        """
        prompt = self.name2prompt[function_name]
        if not prompt['post_filter']:
            output_txt = '\n\n'.join(output_txt_list)
            return self.translate(output_txt, output_lang, api_key=api_key)
        speaker_pattern = self.SPEAKER_PATTERN
        # Drop outputs that contain no speaker header at all.
        tagged = [txt for txt in output_txt_list if speaker_pattern.search(txt)]
        # Newline-join so text at the boundary of two outputs cannot fuse into
        # one word.
        output_pieces = speaker_pattern.split('\n'.join(tagged))
        # Headers absent from the input are skipped -- presumably ones the
        # model invented.
        speakers = set(speaker_pattern.findall(input_txt))
        results = []
        for idx, txt in enumerate(output_pieces):
            if not speaker_pattern.match(txt) or txt not in speakers:
                continue
            if (idx + 1 < len(output_pieces)
                    and not speaker_pattern.match(output_pieces[idx + 1])):
                res = txt + output_pieces[idx + 1]
            else:
                res = txt
            res = self.translate(res, output_lang, api_key=api_key)
            results.append(res.strip())
        return '\n\n'.join(results)

    def __call__(self, api_key, function_name, temperature, output_lang,
                 input_txt, tags):
        """Run ``function_name`` over ``input_txt`` (the Gradio entry point).

        Returns the processed text, or a short message when the API key or the
        function is missing.  The key is passed with every request instead of
        being stored in the global ``openai.api_key``: when requests run
        concurrently, a shared global would let one user's request use another
        user's key.
        """
        if not api_key:
            return 'OPENAI API Key is not set.'
        if not function_name:
            return 'Function is not selected.'
        input_txt_list = self.preprocess(function_name, input_txt)
        output_txt_list = [
            self.llm(function_name, temperature, api_key=api_key,
                     input_txt=txt, tags=tags)
            for txt in input_txt_list]
        # postprocess validates speaker headers against the chunked input.
        return self.postprocess(function_name, '\n'.join(input_txt_list),
                                output_txt_list, output_lang, api_key=api_key)

    def function_names(self):
        """Return the names of all configured functions (a dict keys view)."""
        return self.name2prompt.keys()
def function_name_select_callback(componments, name2prompt, function_name):
    """Build visibility updates for the optional input components.

    Args:
        componments: component names, in the same order as the Gradio outputs
            this callback is wired to.
        name2prompt: mapping from function name to its prompt config.
        function_name: the function just chosen in the dropdown.

    Returns:
        list: one ``gr.update(visible=...)`` per component.  A component is
        visible only if the chosen prompt lists it in ``user_keys``.
    """
    wanted = name2prompt[function_name]['user_keys']
    return [gr.update(visible=(name in wanted)) for name in componments]
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default='prompts/interview.json',
                        help='path to the prompt file')
    parser.add_argument('--temperature', type=float, default=0.7,
                        help='temperature for the llm model')
    args = parser.parse_args()
    # Close the prompt file after loading, and decode it as UTF-8 rather than
    # with the platform's default encoding.
    with open(args.prompt, 'r', encoding='utf-8') as prompt_file:
        prompt_formats = json.load(prompt_file)
    gpt4news = GPT4News(prompt_formats)
    languages = ['Arabic', 'Bengali', 'Chinese (Simplified)',
                 'Chinese (Traditional)', 'Dutch', 'English', 'French',
                 'German', 'Hindi', 'Italian', 'Japanese', 'Korean',
                 'Portuguese', 'Punjabi', 'Russian', 'Spanish', 'Turkish',
                 'Urdu']
    # function_names is a method and must be called; passing the bound method
    # itself to sorted() raises TypeError at start-up.
    sorted_function_names = sorted(gpt4news.function_names())
    default_func = sorted_function_names[0]
    default_user_keys = gpt4news.name2prompt[default_func]['user_keys']
    with gr.Blocks() as demo:
        with gr.Row():
            # Integer scales keep the 30% / 70% split between the two columns.
            with gr.Column(scale=3):
                with gr.Row():
                    api_key = gr.Textbox(
                        lines=1,
                        label='OPENAI API Key',
                        elem_id='api_key_textbox',
                        placeholder='Enter your OPENAI API Key')
                with gr.Row():
                    function_name = gr.Dropdown(
                        sorted_function_names,
                        value=default_func,
                        elem_id='function_dropdown',
                        label='Function',
                        info='choose a function to run')
                with gr.Row():
                    output_lang = gr.Dropdown(
                        languages,
                        value='English',
                        elem_id='output_lang_dropdown',
                        label='Output Language',
                        info='choose a language to output')
                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=args.temperature,
                        step=0.1,
                        interactive=True,
                        label='Temperature',
                        info='higher temperature means more creative')
                with gr.Row():
                    tags = gr.Textbox(
                        lines=1,
                        visible='tags' in default_user_keys,
                        label='Tags',
                        elem_id='tags_textbox',
                        placeholder='Enter tags split by semicolon')
                with gr.Row():
                    input_txt = gr.Textbox(
                        lines=4,
                        visible='input_txt' in default_user_keys,
                        label='Input',
                        elem_id='input_textbox',
                        placeholder='Enter text and press submit')
                with gr.Row():
                    submit = gr.Button('Submit')
                with gr.Row():
                    clear = gr.Button('Clear')
            with gr.Column(scale=7):
                output_txt = gr.Textbox(
                    lines=8,
                    label='Output',
                    elem_id='output_textbox')
        # Show only the inputs the selected function's prompt uses; the name
        # list is in the same order as the outputs [input_txt, tags].
        function_name.select(
            partial(function_name_select_callback, ['input_txt', 'tags'],
                    gpt4news.name2prompt),
            [function_name],
            [input_txt, tags]
        )
        submit.click(
            gpt4news,
            [api_key, function_name, temperature, output_lang,
             input_txt, tags],
            [output_txt])
        # The callback returns three values, so the outputs must be a list of
        # three components.
        clear.click(
            lambda: ['', '', ''],
            None,
            [tags, input_txt, output_txt])
    demo.queue(concurrency_count=6)
    demo.launch()