Spaces:
Runtime error
Runtime error
import torch | |
import json | |
import os | |
from transformers import pipeline, VitsModel, VitsTokenizer, SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor | |
import numpy as np | |
# Install OpenAI Whisper from GitHub at startup.
# NOTE(review): nothing in this file imports `whisper` (speech recognition goes
# through the transformers `pipeline`), so this install may be removable — confirm.
# Run pip via the current interpreter (`python -m pip`) so the package lands in
# the environment actually executing this script, not whatever `pip` is first on
# PATH. check=False keeps the original best-effort behaviour: a failed install
# does not abort startup.
import subprocess
import sys

subprocess.run(
    [sys.executable, "-m", "pip", "install", "git+https://github.com/openai/whisper.git"],
    check=False,
)
import gradio as gr | |
import requests | |
# Chat model name placed in every chat-completion payload.
MODEL = "gpt-3.5-turbo"
# Chat endpoint and credentials are read from the environment (None when unset).
API_URL = os.getenv("API_URL")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# NOTE(review): not referenced anywhere in the code shown.
NUM_THREADS = 2
# First CUDA GPU when available, otherwise the CPU.
# NOTE(review): not referenced elsewhere in the code shown; the models below
# are loaded without a device argument.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
def parse_codeblock(text):
    """Convert fenced code blocks in *text* to HTML for display.

    A line containing three backticks opens a ``<pre><code>`` element whose
    ``class`` is the text after the first three characters (e.g. a language
    tag), unless the line is exactly three backticks, which closes it.
    Every other line is HTML-escaped, and every line after the first is
    prefixed with ``<br/>``. The resulting pieces are joined with no
    separator.

    Args:
        text: Raw text, e.g. a chat-model reply.

    Returns:
        An HTML string safe to render: user text can no longer inject markup.
    """
    import html

    lines = text.split("\n")
    for i, line in enumerate(lines):
        if "```" in line:
            if line != "```":
                # The info string lands inside a quoted attribute, so escape
                # quotes as well as <, > and &.
                lines[i] = f'<pre><code class="{html.escape(line[3:])}">'
            else:
                lines[i] = "</code></pre>"
        else:
            # Escape &, < and > on every text line (including the first one,
            # which was previously emitted raw) so it displays as text.
            escaped = html.escape(line, quote=False)
            lines[i] = "<br/>" + escaped if i > 0 else escaped
    return "".join(lines)
# Speech-recognition pipeline; per its model id, a Whisper-small checkpoint
# for Swedish audio. Used by translate().
pipe = pipeline(model="Sleepyp00/whisper-small-Swedish")

# English text-to-speech model (VITS architecture) and its matching tokenizer,
# both loaded from the same checkpoint. Used by synthesise().
_TTS_CHECKPOINT = "facebook/mms-tts-eng"
tokenizer = VitsTokenizer.from_pretrained(_TTS_CHECKPOINT)
model2 = VitsModel.from_pretrained(_TTS_CHECKPOINT)
# Speech -> English text.
def translate(audio):
    """Run the ``pipe`` speech pipeline on *audio* and return the text.

    Whisper's ``"translate"`` task is requested so the model emits English
    text rather than a source-language transcript; generation is capped at
    256 new tokens.

    Args:
        audio: Any input ``pipe`` accepts (the UI passes a filepath).

    Returns:
        The ``"text"`` field of the pipeline result.
    """
    result = pipe(
        audio,
        max_new_tokens=256,
        generate_kwargs={"task": "translate"},
    )
    text = result["text"]
    return text
# English text -> waveform.
def synthesise(text):
    """Synthesise speech for *text* with the VITS model ``model2``.

    Args:
        text: The sentence to speak.

    Returns:
        The first (and only) waveform of the model's ``audio`` output, as a
        torch tensor computed without gradient tracking.
    """
    token_ids = tokenizer(text, return_tensors="pt")["input_ids"]
    with torch.no_grad():
        waveform_batch = model2(token_ids).audio
    return waveform_batch[0]
def _build_chat_messages(inputs, chat_counter, history):
    """Return the ``messages`` list for one chat-completion request.

    On the first turn (``chat_counter == 0``) only the new user message is
    sent. On later turns the entries of *history* are replayed first with
    alternating roles: even index = user, odd index = assistant.
    """
    messages = []
    if chat_counter != 0:
        for i, data in enumerate(history):
            role = "user" if i % 2 == 0 else "assistant"
            messages.append({"role": role, "content": data})
    messages.append({"role": "user", "content": f"{inputs}"})
    return messages


def _stream_delta_content(line):
    """Return the text fragment carried by one decoded stream line, or None.

    Stream lines have the form ``data: <json>`` and the stream ends with
    ``data: [DONE]``. Lines without that prefix, the end marker, and chunks
    whose first choice has no (or a null) ``delta.content`` all yield None.
    Malformed JSON raises ``ValueError``.
    """
    prefix = "data: "
    if not line.startswith(prefix):
        return None
    body = line[len(prefix):].strip()
    if not body or body == "[DONE]":
        return None
    choices = json.loads(body).get("choices") or []
    if not choices:
        return None
    return (choices[0].get("delta") or {}).get("content")


def gpt_predict(
    inputs,
    request: gr.Request = gr.State([]),
    top_p=1,
    temperature=1,
    chat_counter=0,
    history=None,
    timeout=60,
):
    """Send *inputs* to the chat-completions endpoint and return the reply.

    The request is streamed (``"stream": True``) and the text fragments are
    concatenated into one string.

    Args:
        inputs: The user message.
        request: Unused; kept for signature compatibility.
        top_p: Nucleus-sampling value forwarded in the payload (every turn).
        temperature: Sampling temperature forwarded in the payload (every turn).
        chat_counter: 0 for a fresh conversation; otherwise *history* is
            replayed before *inputs*.
        history: Alternating user/assistant messages, mutated in place: the
            new user message and then the streamed reply are appended.
            Defaults to a fresh list per call (a ``[]`` default would be
            shared by, and grow across, all calls).
        timeout: Seconds passed to ``requests.post`` as its timeout.

    Returns:
        The reply text received so far — ``""`` if the request failed before
        any text arrived. Errors are printed, not raised, so the UI keeps
        working.
    """
    if history is None:
        history = []
    if not API_URL:
        print("error found: API_URL environment variable is not set")
        return ""

    payload = {
        "model": MODEL,
        "messages": _build_chat_messages(inputs, chat_counter, history),
        "temperature": temperature,
        "top_p": top_p,
        "n": 1,
        "stream": True,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {OPENAI_API_KEY}",
    }
    history.append(inputs)

    partial_words = ""
    answer_started = False
    try:
        # `timeout` bounds connecting and each wait for new bytes (not the
        # whole streamed download); without it a stalled server hangs the UI.
        # The `with` block closes the streamed connection when done.
        with requests.post(
            API_URL, headers=headers, json=payload, stream=True, timeout=timeout
        ) as response:
            # Error replies (bad key, rate limit, ...) are not an event
            # stream; report them instead of trying to parse them.
            response.raise_for_status()
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue  # blank lines separate stream events
                content = _stream_delta_content(raw_line.decode("utf-8"))
                if content is None:
                    continue
                partial_words += content
                # Keep the growing assistant reply as the last history entry.
                if answer_started:
                    history[-1] = partial_words
                else:
                    history.append(partial_words)
                    answer_started = True
    except Exception as e:  # UI boundary: log and return what arrived so far
        print(f"error found: {e}")
    return partial_words
# Define the pipeline: speech -> English text -> English speech.
def speech_to_speech_translation(audio):
    """Translate *audio* to English text, then speak that text.

    Args:
        audio: Any input ``translate`` accepts (the UI passes a filepath).

    Returns:
        ``[translated_text, None, (sample_rate, int16_waveform)]`` — one
        entry per Gradio output; the middle (GPT answer) slot is unused here.
    """
    translated_text = translate(audio)
    waveform = synthesise(translated_text).numpy()
    # Clip to [-1, 1] before scaling: samples slightly outside that range
    # would overflow int16 on the cast (platform-dependent wrap-around that
    # is heard as loud clicks).
    pcm = (np.clip(waveform, -1.0, 1.0) * 32767).astype(np.int16)
    # Report the rate the TTS model actually generates at instead of
    # hard-coding it; fall back to the previous constant.
    sampling_rate = getattr(model2.config, "sampling_rate", 16000)
    return [translated_text, None, (sampling_rate, pcm)]
def predict(transType, language, audio, audio_mic = None):
    """Gradio callback: route the uploaded or recorded audio to the chosen output.

    Args:
        transType: ``"Text"``, ``"Audio"`` or ``"GPT answer"``.
        language: Selected source language. NOTE(review): not used by this
            function; every choice goes through the same ``translate``.
        audio: Filepath of an uploaded clip, or None.
        audio_mic: Filepath of a microphone recording, or None; used only
            when no upload was provided.

    Returns:
        A 3-item sequence matching the interface outputs:
        (translated text, GPT answer, (sample_rate, waveform)), with None in
        the slots that do not apply to *transType*.

    Raises:
        gr.Error: if no audio was supplied or *transType* is not recognised,
            so Gradio shows a readable message. Previously the function
            crashed inside the pipeline, or returned None for the three
            outputs.
    """
    print("debug1:", audio, "debug2", audio_mic)
    if not audio and audio_mic:
        audio = audio_mic
    if not audio:
        raise gr.Error("Please upload or record an audio clip.")

    if transType == "Text":
        return translate(audio), None, None
    if transType == "GPT answer":
        req = translate(audio)
        # gpt_predict's `request` argument is unused, so nothing is passed for it.
        return req, gpt_predict(req), None
    if transType == "Audio":
        return speech_to_speech_translation(audio)
    raise gr.Error("Please choose an output format.")
# Define the title etc
title = "Swedish STSOT (Speech To Speech Or Text)"
description="Use Whisper pretrained model to convert swedish audio to english (text or audio)"
supportLangs = ["Swedish", "French (in training)"]
transTypes = ["Text", "Audio", "GPT answer"]
# Each example row gives one value per input component, in order:
# (output format, source language, uploaded file, microphone recording).
examples = [
    ["Text", "Swedish", "./ex1.wav", None],
    ["Audio", "Swedish", "./ex2.wav", None],
]
demo = gr.Interface(
    fn=predict,
    inputs=[
        # Default to "Text" so the callback never receives an empty choice
        # (with no default, an unselected Radio passes None).
        gr.Radio(label="Choose your output format", choices=transTypes, value="Text"),
        # NOTE(review): the visible `predict` does not read this value.
        gr.Radio(label="Choose a source language", choices=supportLangs, value="Swedish"),
        gr.Audio(label="Import an audio", sources="upload", type="filepath"),
        gr.Audio(label="Record an audio", sources="microphone", type="filepath"),
    ],
    outputs=[
        gr.Text(label="Text translation"),
        gr.Text(label="GPT answer"),
        gr.Audio(label="Audio translation", type="numpy"),
    ],
    title=title,
    description=description,
    article="",
    examples=examples,
)
demo.launch()