|
import time |
|
import openai |
|
import gradio as gr |
|
import requests |
|
from pydub import AudioSegment as am |
|
from xml.etree import ElementTree |
|
|
|
# Azure settings. aoai_url / aoai_key are not assigned anywhere else in this
# file: get_aoai_set stores the OpenAI settings directly on the openai module.
# stts_key / stts_region are filled in by get_stts_set.
aoai_url, aoai_key, stts_key, stts_region = "", "", "", ""
openai.api_type = "azure"

# --- GPT-3.5 Playground (Completion API) state ---
prompts = ""          # latest prompt recorded by user_gpt
model_gpt = ""        # deployment name recorded by get_engine_gpt
messages_gpt = []     # flat list of prompts and completions, for display
gpt_x = None          # last raw Completion response; None until the first request

# --- ChatGPT tab (ChatCompletion API) state ---
model_chat = ""
messages_chat = [
    {"role": "system", "content": "You are an AI assistant that helps people find information."},
]
chat_x = None         # last raw ChatCompletion response; None until the first request

# --- DALL·E tab state ---
response_walle = []   # every JSON payload received while generating images

# --- VoiceChat tab state ---
model_vchat = ""
messages_vchat = [
    {"role": "system", "content": "You are an AI assistant that helps people find information and just respond with SSML."},
]
vchat_x = None        # last raw ChatCompletion response; None until the first request
vchat_reply = ""      # last assistant reply text, synthesized by text_to_speech
|
|
|
def get_aoai_set(get_aoai_url, get_aoai_key, get_aoai_API):
    """Apply the Azure OpenAI settings from the Settings tab to the openai module.

    Blank (falsy) values leave the corresponding openai setting untouched.

    Args:
        get_aoai_url: Resource endpoint, stored as ``openai.api_base``.
        get_aoai_key: API key, stored as ``openai.api_key``.
        get_aoai_API: API version, stored as ``openai.api_version``.

    Returns:
        A tuple of three ``gr.update`` objects echoing the inputs back.
    """
    entered = (get_aoai_url, get_aoai_key, get_aoai_API)
    for attribute, setting in zip(("api_base", "api_key", "api_version"), entered):
        if setting:
            setattr(openai, attribute, setting)
    return tuple(gr.update(value=setting) for setting in entered)
|
|
|
def get_stts_set(get_stts_key, get_stts_region):
    """Store the Azure Speech key and region used by the VoiceChat tab.

    Blank (falsy) values keep the previously stored setting.

    Args:
        get_stts_key: Speech service subscription key.
        get_stts_region: Speech service region, e.g. 'westus'.

    Returns:
        A tuple of two ``gr.update`` objects echoing the inputs back.
    """
    global stts_key, stts_region
    stts_key = get_stts_key or stts_key
    stts_region = get_stts_region or stts_region
    return tuple(gr.update(value=entered) for entered in (get_stts_key, get_stts_region))
|
|
|
# Root Gradio app: every tab and event listener is created inside this Blocks
# context, and the app is started by page.launch() at the end of the file.
# NOTE(review): .style(container=False) is a Gradio 3.x API (removed in
# Gradio 4) — pin gradio<4 or migrate before upgrading.
with gr.Blocks() as page: |

with gr.Tabs(): |

# Settings tab: Azure OpenAI endpoint/key/API version and Azure Speech key/region.
with gr.TabItem("Settings"): |

gr.HTML(""" |

<p>Please read and set parameters before switching to another tab.</p> <br>Your Azure OpenAI Key and other Azure Cognitive Service Keys |

will not be saved or viewed by anyone. <br><br> |

You can find these parameters in Azure Portal. Select Azure OpenAI resource or Cognitive Service resource like Speech, and then select |

'Keys and Endpoint' from left panel. <br> For Azure OpenAI service, you need to provide the resource URL and key for REST API. You also |

need to set the API version or just use the default value. The Azure OpenAI model which is deployed needs to be set in each tab. Because |

you may need to run different models at the same time. Don't forget to hit 'Enter' with every input. <br> For Azure Cognitive services, |

you need to provide a Key for REST API, and also need to provide a service region, for example, 'westus'. The app will create the endpoint |

URL by itself. <br><br>Thank you.<br><br><br> Azure OpenAI Service parameters for ChatGPT/GPT. Please input these settings and hit the |

'Enter' key. |

""") |

# Azure OpenAI connection row: endpoint, key and API version.
with gr.Row(): |

with gr.Column(scale=0.6): |

ui_aoai_url = gr.Textbox(placeholder="Like https://your-url-base.openai.azure.com , etc.", |

label="- Azure OpenAI service API endpoint:", lines=1).style(container=False) |

with gr.Column(scale=0.2): |

ui_aoai_key = gr.Textbox(placeholder="Please enter your Azure OpenAI API key here.", |

label="- Azure OpenAI service API Key: ", lines=1, type='password').style(container=False) |

# The default API version below is only applied to openai.api_version once one
# of these three textboxes is submitted (get_aoai_set runs on .submit).
with gr.Column(scale=0.2): |

ui_aoai_api = gr.Textbox(value="2023-03-15-preview", label="· Azure OpenAI service API version: ", |

lines=1, interactive=True).style(container=False) |

# Azure Speech row: key and region, consumed by the VoiceChat tab.
gr.HTML("Azure Cognitive Speech Service parameters to use VoiceChat. ") |

with gr.Row(): |

with gr.Column(scale=0.6): |

ui_stts_key = gr.Textbox(placeholder="Please enter your speech service API key if you want to try VoiceChat. " + |

"Please input these settings and hit 'Enter' key.", |

label="- Azure Cognitive Speech service API Key: ", interactive=True, type='password').style(container=False) |

with gr.Column(scale=0.4): |

ui_stts_loc = gr.Textbox(placeholder="Please enter your speech service region.", |

label="- Azure Cognitive Speech service region: ", interactive=True).style(container=False) |

# Submitting any field re-applies all related fields, so pressing Enter in
# whichever box was edited last is enough.
ui_aoai_url.submit(get_aoai_set, [ui_aoai_url, ui_aoai_key, ui_aoai_api], [ui_aoai_url, ui_aoai_key, ui_aoai_api]) |

ui_aoai_key.submit(get_aoai_set, [ui_aoai_url, ui_aoai_key, ui_aoai_api], [ui_aoai_url, ui_aoai_key, ui_aoai_api]) |

ui_aoai_api.submit(get_aoai_set, [ui_aoai_url, ui_aoai_key, ui_aoai_api], [ui_aoai_url, ui_aoai_key, ui_aoai_api]) |

ui_stts_key.submit(get_stts_set, [ui_stts_key, ui_stts_loc], [ui_stts_key, ui_stts_loc]) |

ui_stts_loc.submit(get_stts_set, [ui_stts_key, ui_stts_loc], [ui_stts_key, ui_stts_loc]) |



# GPT-3.5 Playground tab: prompt box + chatbot backed by the Completion API.
# The event handlers for this tab are registered further down in the file.
with gr.TabItem("GPT-3.5 Playground"): |

ui_chatbot_gpt = gr.Chatbot(label="GPT Playground:") |

with gr.Row(): |

with gr.Column(scale=0.9): |

ui_prompt_gpt = gr.Textbox(placeholder="Please enter your prompt here.", show_label=False).style(container=False) |

with gr.Column(scale=0.1, min_width=100): |

ui_clear_gpt = gr.Button("Clear Input", ) |

with gr.Accordion("Expand to config parameters:", open=True): |

ui_memo_gpt = gr.HTML("GPT-3.5 playground use Completion(). So you just need to provide model name as engine parameter.") |

ui_model_gpt = gr.Textbox(placeholder="Azure OpenAI GPT model deployment name. ", |

label="- Azure OpenAI deployment name:", lines=1).style(container=False) |

# gr.Slider(minimum, maximum, value, ...): defaults are temperature 0.9,
# max tokens 1000, top_p 0.5.
with gr.Row(): |

ui_temp_gpt = gr.Slider(0.1, 1.0, 0.9, step=0.1, label="Temperature", interactive=True) |

ui_max_tokens_gpt = gr.Slider(100, 4000, 1000, step=100, label="Max Tokens", interactive=True) |

ui_top_p_gpt = gr.Slider(0.1, 1.0, 0.5, step=0.1, label="Top P", interactive=True) |

# Detail panel: raw last response or the prompt/completion history.
with gr.Accordion("Select radio button to see detail:", open=False): |

ui_res_radio_gpt = gr.Radio(["Response from OpenAI Model", "Prompt messages history"], label="Show OpenAI response:", interactive=True) |

ui_response_gpt = gr.TextArea(show_label=False, interactive=False).style(container=False) |
|
|
|
def get_parameters_gpt(slider_1, slider_2, slider_3):
    """Record the Playground slider values on their components and log them.

    Args:
        slider_1: Temperature slider value.
        slider_2: Max Tokens slider value.
        slider_3: Top P slider value.
    """
    gpt_components = (ui_temp_gpt, ui_max_tokens_gpt, ui_top_p_gpt)
    for component, slider_value in zip(gpt_components, (slider_1, slider_2, slider_3)):
        component.value = slider_value
    print("Log - Updated GPT parameters: Temperature=", ui_temp_gpt.value,
          " Max Tokens=", ui_max_tokens_gpt.value, " Top_P=", ui_top_p_gpt.value)
|
|
|
def get_engine_gpt(get_aoai_model):
    """Remember the Playground deployment name and echo it back to its textbox.

    Args:
        get_aoai_model: Azure OpenAI deployment name entered by the user.

    Returns:
        A ``gr.update`` carrying the same name.
    """
    global model_gpt
    model_gpt = get_aoai_model
    return gr.update(value=model_gpt)
|
|
|
def select_response_gpt(radio):
    """Return the content for the Playground detail panel.

    Args:
        radio: Selected label. "Response from OpenAI Model" shows the last raw
            Completion response; any other value shows ``messages_gpt``.

    Returns:
        A ``gr.update`` for the detail TextArea.
    """
    if radio == "Response from OpenAI Model":
        # gpt_x is only created by bot_gpt; before the first request it may not
        # exist, which previously raised NameError.
        return gr.update(value=globals().get("gpt_x"))
    return gr.update(value=messages_gpt)
|
|
|
def user_gpt(user_message, history):
    """Record a Playground prompt and add it to the chat as a pending turn.

    Args:
        user_message: Text typed into the prompt box.
        history: Current Chatbot value (list of [user, bot] pairs).

    Returns:
        ``("", new_history)``: clears the prompt box and appends
        ``[user_message, None]`` to the history.
    """
    global prompts
    pending_turn = [user_message, None]
    prompts = user_message
    messages_gpt.append(user_message)
    return "", history + [pending_turn]
|
|
|
def bot_gpt(history):
    """Run a single-turn Completion for the latest Playground prompt.

    Sends the module-level ``prompts`` (set by ``user_gpt``) to the deployment
    ``model_gpt``, using the values recorded on the Playground sliders. Stores
    the raw response in ``gpt_x`` and appends the completion text to
    ``messages_gpt``.

    Args:
        history: Chatbot value; a pending turn is ``[prompt, None]`` at the end.

    Returns:
        ``history`` with the pending turn's reply filled in, or unchanged when
        there is no pending turn.
    """
    global gpt_x
    # Nothing to answer: empty chat, or the last turn already has a reply.
    if not history or history[-1][1] is not None:
        return history
    # Log the deployment actually used (the textbox component's .value is not
    # updated by user input, so printing it was misleading).
    print("Log - GPT Playground deployment:", model_gpt)
    gpt_x = openai.Completion.create(
        engine=model_gpt,
        prompt=prompts,
        # Honour the Playground sliders instead of hard-coded values.
        temperature=ui_temp_gpt.value,
        max_tokens=int(ui_max_tokens_gpt.value),  # the API requires an integer
        top_p=ui_top_p_gpt.value,
        frequency_penalty=0,
        presence_penalty=0,
        best_of=1,
        stop=None,
    )
    gpt_reply = gpt_x.choices[0].text
    messages_gpt.append(gpt_reply)
    history[-1][1] = gpt_reply
    return history
|
|
|
# Event wiring for the GPT-3.5 Playground tab (registration order preserved).
gpt_sliders = [ui_temp_gpt, ui_max_tokens_gpt, ui_top_p_gpt]

ui_model_gpt.submit(get_engine_gpt, inputs=ui_model_gpt, outputs=ui_model_gpt)
for gpt_slider in gpt_sliders:
    # Each slider reports all three values so get_parameters_gpt can log them together.
    gpt_slider.change(get_parameters_gpt, inputs=gpt_sliders)
# Show the prompt immediately (unqueued), then fetch the completion.
ui_prompt_gpt.submit(
    user_gpt,
    inputs=[ui_prompt_gpt, ui_chatbot_gpt],
    outputs=[ui_prompt_gpt, ui_chatbot_gpt],
    queue=False,
).then(bot_gpt, inputs=ui_chatbot_gpt, outputs=ui_chatbot_gpt)
ui_clear_gpt.click(lambda: None, inputs=None, outputs=ui_chatbot_gpt, queue=False)
ui_res_radio_gpt.change(select_response_gpt, inputs=ui_res_radio_gpt, outputs=ui_response_gpt)
|
|
|
# ChatGPT tab: multi-turn chat backed by the ChatCompletion API.
with gr.TabItem("ChatGPT on GPT-4"): |

ui_chatbot_chat = gr.Chatbot(label="ChatGPT:") |

with gr.Row(): |

with gr.Column(scale=0.9): |

ui_prompt_chat = gr.Textbox(placeholder="Please enter your prompt here.", show_label=False).style(container=False) |

with gr.Column(scale=0.1, min_width=100): |

ui_clear_chat = gr.Button("Clear Chat") |

# NOTE(review): this nested gr.Blocks() appears to only wrap the accordions
# below; confirm it is needed inside the outer Blocks context.
with gr.Blocks(): |

with gr.Accordion("Expand to config parameters:", open=True): |

gr.HTML("ChatGPT use ChatCompletion(). Here is the default system prompt, you can change it to your own prompt.") |

# Default text matches the initial system message in messages_chat.
ui_prompt_sys = gr.Textbox(value="You are an AI assistant that helps people find information.", |

label="- Here is the default system prompt, you can change it to your own prompt.", |

interactive=True).style(container=False) |

ui_model_chat = gr.Textbox(placeholder="Azure OpenAI model deployment name. ", |

label="- Azure OpenAI GPT-3.5/4 deployment name:", lines=1).style(container=False) |

# gr.Slider(minimum, maximum, value, ...): defaults are temperature 0.7,
# max tokens 2000, top_p 0.9.
with gr.Row(): |

ui_temp_chat = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature", interactive=True) |

ui_max_tokens_chat = gr.Slider(100, 8000, 2000, step=100, label="Max Tokens", interactive=True) |

ui_top_p_chat = gr.Slider(0.05, 1.0, 0.9, step=0.1, label="Top P", interactive=True) |

# Detail panel: raw last response or the full messages_chat history.
with gr.Accordion("Select radio button to see detail:", open=False): |

ui_res_radio_chat = gr.Radio(["Response from OpenAI Model", "Prompt messages history"], label="Show OpenAI response:", interactive=True) |

ui_response_chat = gr.TextArea(show_label=False, interactive=False).style(container=False) |
|
|
|
def get_parameters_chat(slider_1, slider_2, slider_3):
    """Record the ChatGPT slider values on their components and log them.

    Args:
        slider_1: Temperature slider value.
        slider_2: Max Tokens slider value.
        slider_3: Top P slider value.
    """
    ui_temp_chat.value, ui_max_tokens_chat.value, ui_top_p_chat.value = slider_1, slider_2, slider_3
    print("Log - Updated chatGPT parameters: Temperature=", ui_temp_chat.value,
          " Max Tokens=", ui_max_tokens_chat.value, " Top_P=", ui_top_p_chat.value)
|
|
|
def get_engine_chat(get_aoai_model):
    """Remember the ChatGPT deployment name and echo it back to its textbox.

    Args:
        get_aoai_model: Azure OpenAI deployment name entered by the user.

    Returns:
        A ``gr.update`` carrying the same name.
    """
    global model_chat
    model_chat = get_aoai_model
    return gr.update(value=model_chat)
|
|
|
def select_response_chat(radio):
    """Return the content for the ChatGPT detail panel.

    Args:
        radio: Selected label. "Response from OpenAI Model" shows the last raw
            ChatCompletion response; any other value shows ``messages_chat``.

    Returns:
        A ``gr.update`` for the detail TextArea.
    """
    if radio == "Response from OpenAI Model":
        # chat_x is only created by bot_chat; before the first reply it may not
        # exist, which previously raised NameError.
        return gr.update(value=globals().get("chat_x"))
    return gr.update(value=messages_chat)
|
|
|
def user_chat(user_message, history):
    """Append a user turn to the ChatGPT conversation and the chatbot.

    Args:
        user_message: Text typed into the prompt box.
        history: Current Chatbot value (list of [user, bot] pairs).

    Returns:
        ``("", new_history)``: clears the prompt box and appends
        ``[user_message, None]`` as the pending turn.
    """
    user_turn = {"role": "user", "content": user_message}
    messages_chat.append(user_turn)
    updated_history = history + [[user_message, None]]
    return "", updated_history
|
|
|
def bot_chat(history):
    """Send the ChatGPT conversation to Azure OpenAI and fill in the reply.

    Uses the module-level ``messages_chat`` (system prompt plus all turns,
    including the user turn appended by ``user_chat``) and the values recorded
    on the ChatGPT sliders. Stores the raw response in ``chat_x`` and appends
    the assistant reply to ``messages_chat``.

    Args:
        history: Chatbot value; a pending turn is ``[user_message, None]`` at the end.

    Returns:
        ``history`` with the pending turn's reply filled in, or unchanged when
        there is no pending turn.
    """
    global chat_x
    # Nothing to answer: empty chat, or the last turn already has a reply.
    if not history or history[-1][1] is not None:
        return history
    chat_x = openai.ChatCompletion.create(
        engine=model_chat,
        messages=messages_chat,
        temperature=ui_temp_chat.value,
        # The API rejects non-integer max_tokens; slider values may be floats.
        max_tokens=int(ui_max_tokens_chat.value),
        top_p=ui_top_p_chat.value,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
    )
    ui_response_chat.value = chat_x
    print(ui_response_chat.value)
    chat_reply = chat_x.choices[0].message.content
    messages_chat.append({"role": "assistant", "content": chat_reply})
    history[-1][1] = chat_reply
    return history
|
|
|
def reset_sys(sysmsg):
    """Start a fresh ChatGPT conversation seeded with ``sysmsg`` as system prompt.

    Rebinds the module-level ``messages_chat`` to a one-element list.

    Args:
        sysmsg: System prompt text.
    """
    global messages_chat
    system_turn = dict(role="system", content=sysmsg)
    messages_chat = [system_turn]
|
|
|
# Event wiring for the ChatGPT tab (registration order preserved).
chat_sliders = [ui_temp_chat, ui_max_tokens_chat, ui_top_p_chat]

ui_model_chat.submit(get_engine_chat, inputs=ui_model_chat, outputs=ui_model_chat)
ui_res_radio_chat.change(select_response_chat, inputs=ui_res_radio_chat, outputs=ui_response_chat)
for chat_slider in chat_sliders:
    # Each slider reports all three values so get_parameters_chat can log them together.
    chat_slider.change(get_parameters_chat, inputs=chat_sliders)
ui_prompt_sys.submit(reset_sys, inputs=ui_prompt_sys)
# Show the user turn immediately (unqueued), then fetch the assistant reply.
ui_prompt_chat.submit(
    user_chat,
    inputs=[ui_prompt_chat, ui_chatbot_chat],
    outputs=[ui_prompt_chat, ui_chatbot_chat],
    queue=False,
).then(bot_chat, inputs=ui_chatbot_chat, outputs=ui_chatbot_chat)
# Clearing empties the chatbot and then restarts the conversation from the system prompt.
ui_clear_chat.click(lambda: None, inputs=None, outputs=ui_chatbot_chat, queue=False).then(
    reset_sys, inputs=ui_prompt_sys
)
|
|
|
|
|
# DALL·E 2 tab: prompt box, generated image, and a panel that shows the raw
# API payloads collected in response_walle.
with gr.TabItem("DALL·E 2 Painting"): |

ui_prompt_walle = gr.Textbox(placeholder="Please enter your prompt here to generate image.", |

show_label=False).style(container=False) |

ui_image_walle = gr.Image() |

# NOTE(review): the accordion title mentions a radio button, but this tab has
# none; the TextArea simply shows the logged responses.
with gr.Accordion("Select radio button to see detail:", open=False): |

ui_response_walle = gr.TextArea(show_label=False, interactive=False).style(container=False) |
|
|
|
def get_image_walle(prompt_walle):
    """Generate a 1024x1024 image with the Azure OpenAI DALL·E 2 preview API.

    Posts ``prompt_walle`` to ``<openai.api_base>/dalle/text-to-image`` and polls
    the returned ``Operation-Location`` URL until the operation succeeds. Every
    JSON payload received is appended to the module-level ``response_walle``.

    Args:
        prompt_walle: Text caption describing the image.

    Returns:
        A ``gr.update`` whose value is the generated image URL
        (``result.contentUrl`` of the final payload).

    Raises:
        RuntimeError: if the submission is not accepted, the operation reports
            a failure status, or it has not succeeded within ``max_wait_seconds``.
    """
    walle_api_version = '2022-08-03-preview'
    max_wait_seconds = 300
    # Accept endpoints typed with or without a trailing slash.
    base_url = openai.api_base.rstrip('/') + '/'
    url = "{}dalle/text-to-image?api-version={}".format(base_url, walle_api_version)
    headers = {"api-key": openai.api_key, "Content-Type": "application/json"}
    body = {"caption": prompt_walle, "resolution": "1024x1024"}

    submission = requests.post(url, headers=headers, json=body)
    submission_payload = submission.json()
    response_walle.append(submission_payload)
    print("Log - WALL·E status: {}".format(submission_payload))

    operation_location = submission.headers.get('Operation-Location')
    if not operation_location:
        # Rejected requests carry no Operation-Location; the payload holds the error.
        raise RuntimeError("DALL·E request was not accepted (HTTP {}): {}".format(
            submission.status_code, submission_payload))
    # Honour the service's polling hint, but never poll in a tight loop.
    poll_seconds = max(1, int(submission.headers.get('Retry-After', 1)))
    deadline = time.monotonic() + max_wait_seconds

    while True:
        time.sleep(poll_seconds)
        response = requests.get(operation_location, headers=headers)
        payload = response.json()
        response_walle.append(payload)
        print("Log - WALL·E status: {}".format(payload))
        status = str(payload.get('status', ''))
        if status.lower() == "succeeded":
            return gr.update(value=payload['result']['contentUrl'])
        # NOTE(review): terminal failure status names are assumed from Azure
        # long-running-operation conventions; confirm against the preview API.
        if status.lower() in ("failed", "canceled", "cancelled", "deleted"):
            raise RuntimeError("DALL·E generation ended with status '{}': {}".format(status, payload))
        if time.monotonic() >= deadline:
            raise RuntimeError("DALL·E generation did not finish within {} seconds.".format(max_wait_seconds))
|
|
|
def get_response_walle():
    """Return a ``gr.update`` showing every DALL·E payload logged so far.

    Only reads the module-level ``response_walle`` list, so no ``global``
    declaration is needed.
    """
    logged_payloads = response_walle
    return gr.update(value=logged_payloads)
|
|
|
# Generate the image first, then refresh the raw-response panel.
image_event_walle = ui_prompt_walle.submit(
    get_image_walle, inputs=ui_prompt_walle, outputs=ui_image_walle, queue=False
)
image_event_walle.then(get_response_walle, inputs=None, outputs=ui_response_walle)
|
|
|
# VoiceChat tab: microphone recording -> Azure speech-to-text -> ChatCompletion
# -> Azure text-to-speech (see the event chain registered on ui_voice_inc_vchat).
with gr.TabItem("VoiceChat on GPT"): |

with gr.Row(): |

# Left column: configuration, microphone input and the synthesized reply.
with gr.Column(): |

with gr.Accordion("Expand to config parameters:", open=True): |

ui_prompt_sys_vchat = gr.Textbox(value="You are an AI assistant that helps people find information and just respond with SSML.", |

label="- Here is the default system prompt, you can change it to your own prompt.", |

interactive=True).style(container=False) |

ui_model_vchat = gr.Textbox(placeholder="- Azure OpenAI model deployment name. ", |

label="- Azure OpenAI GPT-3.5/4 deployment name:", lines=1).style(container=False) |



# type="filepath": handlers receive a path to the recorded audio file.
# NOTE(review): source= is the Gradio 3.x keyword (Gradio 4 uses sources=[...]).
ui_voice_inc_vchat = gr.Audio(source="microphone", type="filepath") |

# Output player; text_to_speech returns the synthesized WAV path for it.
ui_voice_out_vchat = gr.Audio(value=None, type="filepath", interactive=False).style(container=False) |

with gr.Accordion("Expand to config parameters:", open=False): |

# gr.Slider(minimum, maximum, value, ...): defaults are temperature 0.7,
# max tokens 800, top_p 0.9.
with gr.Row(): |

ui_temp_vchat = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature", interactive=True) |

ui_max_tokens_vchat = gr.Slider(100, 8000, 800, step=100, label="Max Tokens", interactive=True) |

ui_top_p_vchat = gr.Slider(0.05, 1.0, 0.9, step=0.1, label="Top P", interactive=True) |

# Right column: transcript chatbot and raw-response/history panel.
with gr.Column(): |

ui_chatbot_vchat = gr.Chatbot(label="Voice to ChatGPT:") |

with gr.Accordion("Select radio button to see detail:", open=False): |

ui_res_radio_vchat = gr.Radio(["Response from OpenAI Model", "Prompt messages history"], label="Show OpenAI response:", interactive=True) |

ui_response_vchat = gr.TextArea(show_label=False, interactive=False).style(container=False) |
|
|
|
def get_parameters_vchat(slider_1, slider_2, slider_3):
    """Record the VoiceChat slider values on their components and log them.

    Args:
        slider_1: Temperature slider value.
        slider_2: Max Tokens slider value.
        slider_3: Top P slider value.
    """
    ui_temp_vchat.value, ui_max_tokens_vchat.value, ui_top_p_vchat.value = slider_1, slider_2, slider_3
    print("Log - Updated chatGPT parameters: Temperature=", ui_temp_vchat.value,
          " Max Tokens=", ui_max_tokens_vchat.value, " Top_P=", ui_top_p_vchat.value)
|
|
|
def get_engine_vchat(get_aoai_model):
    """Remember the VoiceChat deployment name and echo it back to its textbox.

    Args:
        get_aoai_model: Azure OpenAI deployment name entered by the user.

    Returns:
        A ``gr.update`` carrying the same name.
    """
    global model_vchat
    model_vchat = get_aoai_model
    return gr.update(value=model_vchat)
|
|
|
def select_response_vchat(radio):
    """Return the content for the VoiceChat detail panel.

    Args:
        radio: Selected label. "Response from OpenAI Model" shows the last raw
            ChatCompletion response; any other value shows ``messages_vchat``.

    Returns:
        A ``gr.update`` for the detail TextArea.
    """
    if radio == "Response from OpenAI Model":
        # vchat_x is only created by bot_vchat; before the first reply it may not
        # exist, which previously raised NameError.
        return gr.update(value=globals().get("vchat_x"))
    return gr.update(value=messages_vchat)
|
|
|
def speech_to_text(voice_message, language='zh-CN'):
    """Transcribe a recorded WAV file with the Azure Speech short-audio REST API.

    The file at ``voice_message`` is converted in place to 16 kHz mono 16-bit
    PCM, which is the format declared in the request's Content-Type. It is then
    posted to the region's ``stt.speech.microsoft.com`` endpoint using the
    module-level ``stts_key`` / ``stts_region`` (set by ``get_stts_set``).

    Args:
        voice_message: Path to the recorded WAV file; it is overwritten in place.
        language: Recognition locale (defaults to the original 'zh-CN').

    Returns:
        The recognized text (``DisplayText`` of the service response).

    Raises:
        RuntimeError: if the Speech settings are missing, the HTTP request fails,
            or the response contains no ``DisplayText`` (e.g. nothing recognized).
    """
    if not stts_key or not stts_region:
        raise RuntimeError("Azure Speech key and region must be set on the Settings tab first.")

    # Match the audio to the format announced in the Content-Type header below.
    voice_wav = am.from_file(voice_message, format='wav')
    voice_wav = voice_wav.set_frame_rate(16000).set_channels(1).set_sample_width(2)
    voice_wav.export(voice_message, format='wav')

    constructed_url = ("https://" + stts_region + ".stt.speech.microsoft.com/"
                       + 'speech/recognition/conversation/cognitiveservices/v1')
    params = {
        'language': language,
        'format': 'detailed',
    }
    headers = {
        'Ocp-Apim-Subscription-Key': stts_key,
        'Content-Type': 'audio/wav; codecs=audio/pcm; samplerate=16000',
        'Accept': 'application/json;text/xml',
    }
    with open(voice_message, 'rb') as audio_file:
        body = audio_file.read()
    response = requests.post(constructed_url, params=params, headers=headers, data=body)

    if response.status_code != 200:
        print("\nLog - Status code: " + str(response.status_code) + "\nSomething went wrong. Check your subscription key and headers.\n")
        print("Reason: " + str(response.reason) + "\n")
        raise RuntimeError("Speech-to-text request failed: {} {}".format(response.status_code, response.reason))

    rs = response.json()
    print(rs)
    if 'DisplayText' not in rs:
        raise RuntimeError("Speech-to-text returned no text (RecognitionStatus: {}).".format(rs.get('RecognitionStatus')))
    return rs['DisplayText']
|
|
|
def text_to_speech(language='zh-cn', voice_name='zh-CN-XiaoxiaoNeural'):
    """Synthesize the latest VoiceChat reply to 'chatgpt.wav' via Azure TTS REST.

    Reads the module-level ``vchat_reply`` (set by ``bot_vchat``) and
    ``stts_key`` / ``stts_region``, posts an SSML document to the region's
    ``tts.speech.microsoft.com`` endpoint, and writes the returned audio to
    'chatgpt.wav' in the working directory.

    Args:
        language: xml:lang for the SSML document (defaults to the original 'zh-cn').
        voice_name: Azure neural voice to use.

    Returns:
        A ``gr.update`` for the output Audio: the WAV path on success, or
        ``value=None`` when there is nothing to speak, Speech is not configured,
        or the request fails (so a previous answer is not replayed).
    """
    reply = globals().get("vchat_reply")
    if not reply:
        print("\nLog - No VoiceChat reply to synthesize yet.\n")
        return gr.update(value=None, interactive=True)
    if not stts_key or not stts_region:
        print("\nLog - Azure Speech key/region not set; skipping text-to-speech.\n")
        return gr.update(value=None, interactive=True)

    # The subscription key is a secret: never print it.
    constructed_url = "https://" + stts_region + ".tts.speech.microsoft.com/" + 'cognitiveservices/v1'
    headers = {
        'Ocp-Apim-Subscription-Key': stts_key,
        'Content-Type': 'application/ssml+xml',
        'X-Microsoft-OutputFormat': 'riff-24khz-16bit-mono-pcm',
        'User-Agent': 'Voice ChatGPT',
    }
    xml_body = ElementTree.Element('speak', version='1.0')
    xml_body.set('{http://www.w3.org/XML/1998/namespace}lang', language)
    voice = ElementTree.SubElement(xml_body, 'voice')
    voice.set('{http://www.w3.org/XML/1998/namespace}lang', language)
    voice.set('name', voice_name)
    # NOTE(review): the reply is inserted as escaped text, so any SSML markup the
    # model returns (the VoiceChat system prompt asks for SSML) is not parsed as
    # markup — confirm this is intended.
    voice.text = reply
    body = ElementTree.tostring(xml_body)
    response = requests.post(constructed_url, headers=headers, data=body)

    tts_file = "chatgpt.wav"
    if response.status_code != 200:
        print("\nStatus code: " + str(response.status_code) + "\nSomething went wrong. Check your subscription key and headers.\n")
        print("Reason: " + str(response.reason) + "\n")
        return gr.update(value=None, interactive=True)

    with open(tts_file, 'wb') as audio:
        audio.write(response.content)
    print("\nStatus code: " + str(response.status_code) + "\nYour TTS is ready for playback.\n")
    return gr.update(value=tts_file, interactive=True)
|
|
|
def user_vchat(user_voice_message, history):
    """Transcribe a recording and add it as a pending VoiceChat user turn.

    Args:
        user_voice_message: Path to the recorded audio file, or None/empty when
            the Audio component was cleared (its change event also fires then).
        history: Current Chatbot value (list of [user, bot] pairs).

    Returns:
        ``history`` with ``[transcript, None]`` appended, or ``history``
        unchanged when there is no recording to transcribe.
    """
    # Clearing the microphone widget triggers .change with no file; nothing to do.
    if not user_voice_message:
        return history
    user_message = speech_to_text(user_voice_message)
    messages_vchat.append({"role": "user", "content": user_message})
    return history + [[user_message, None]]
|
|
|
def bot_vchat(history):
    """Send the VoiceChat conversation to Azure OpenAI and fill in the reply.

    Uses the module-level ``messages_vchat`` and the values recorded on the
    VoiceChat tab's own sliders. Stores the raw response in ``vchat_x`` and the
    reply text in ``vchat_reply`` (which ``text_to_speech`` synthesizes), and
    appends the assistant turn to ``messages_vchat``.

    Args:
        history: Chatbot value; a pending turn is ``[transcript, None]`` at the end.

    Returns:
        ``history`` with the pending turn's reply filled in, or unchanged when
        there is no pending turn (e.g. the recording was cleared).
    """
    global vchat_x, vchat_reply
    # Nothing to answer: empty chat, or the last turn already has a reply.
    if not history or history[-1][1] is not None:
        return history
    vchat_x = openai.ChatCompletion.create(
        engine=model_vchat,
        messages=messages_vchat,
        # Use this tab's sliders (previously the ChatGPT tab's sliders were read).
        temperature=ui_temp_vchat.value,
        max_tokens=int(ui_max_tokens_vchat.value),  # the API requires an integer
        top_p=ui_top_p_vchat.value,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
    )
    ui_response_vchat.value = vchat_x
    print(ui_response_vchat.value)
    vchat_reply = vchat_x.choices[0].message.content
    messages_vchat.append({"role": "assistant", "content": vchat_reply})
    history[-1][1] = vchat_reply
    return history
|
|
|
def reset_sys_vchat(sysmsg):
    """Restart the VoiceChat conversation with ``sysmsg`` as its system prompt.

    ``messages_vchat`` is replaced in place, so every function that reads the
    module-level list sees the new conversation.

    Args:
        sysmsg: System prompt text from the VoiceChat settings textbox.
    """
    messages_vchat[:] = [{"role": "system", "content": sysmsg}]


# Event wiring for the VoiceChat tab.
vchat_sliders = [ui_temp_vchat, ui_max_tokens_vchat, ui_top_p_vchat]

ui_model_vchat.submit(get_engine_vchat, ui_model_vchat, ui_model_vchat)
ui_res_radio_vchat.change(select_response_vchat, ui_res_radio_vchat, ui_response_vchat)
# Keep the VoiceChat slider components' .value in sync. Previously the ChatGPT
# sliders were registered a second time here and these were never wired.
ui_temp_vchat.change(get_parameters_vchat, vchat_sliders)
ui_max_tokens_vchat.change(get_parameters_vchat, vchat_sliders)
ui_top_p_vchat.change(get_parameters_vchat, vchat_sliders)
# The system-prompt textbox previously had no handler, so edits were ignored.
ui_prompt_sys_vchat.submit(reset_sys_vchat, ui_prompt_sys_vchat)
# Recording -> transcript turn -> assistant reply -> synthesized audio.
ui_voice_inc_vchat.change(user_vchat, [ui_voice_inc_vchat, ui_chatbot_vchat], ui_chatbot_vchat, queue=False).then(
    bot_vchat, ui_chatbot_vchat, ui_chatbot_vchat, queue=False).then(text_to_speech, None, ui_voice_out_vchat)
|
|
|
|
|
# Start the web server only when run as a script, so importing this module
# (e.g. for tests or Gradio's reload mode) does not launch it.
if __name__ == "__main__":
    page.launch(share=False)