# Spaces: Runtime error
import os | |
import gradio as gr | |
import torch | |
import numpy as np | |
import librosa | |
from efficientat.models.MobileNetV3 import get_model as get_mobilenet, get_ensemble_model | |
from efficientat.models.preprocess import AugmentMelSTFT | |
from efficientat.helpers.utils import NAME_TO_WIDTH, labels | |
from torch import autocast | |
from contextlib import nullcontext | |
from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate | |
from langchain.chains.conversation.memory import ConversationalBufferWindowMemory | |
# Name of the pretrained checkpoint; also used to look up the width multiplier.
MODEL_NAME = "mn40_as"

# Key handed to the OpenAI client when the chat chain is built.
# A missing SESSION_TOKEN environment variable raises KeyError here.
session_token = os.environ["SESSION_TOKEN"]

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the audio tagging model once at import time and put it in eval mode.
model = get_mobilenet(
    width_mult=NAME_TO_WIDTH(MODEL_NAME),
    pretrained_name=MODEL_NAME,
)
model.to(device)
model.eval()

# Conversation state shared across calls: the current chain and the class name
# it was built for. The initial "c" can never equal a capitalized class name,
# so the first request always builds a chain.
chatgpt_chain = None
cached_audio_class = "c"

# Module-level placeholders; nothing in the visible code reads them.
template = None
prompt = None
chain = None
def format_classname(classname):
    """Return ``classname`` with its first character upper-cased and the rest lower-cased."""
    display_name = classname.capitalize()
    return display_name
def audio_tag(
    audio_path,
    human_input,
    sample_rate=32000,
    window_size=800,
    hop_size=320,
    n_mels=128,
    cuda=True,
):
    """Tag an audio clip and reply to ``human_input`` as the detected sound class.

    The clip is loaded as mono audio at ``sample_rate``, converted to a mel
    spectrogram and scored by the module-level ``model``. The label with the
    highest score is passed, together with ``human_input``, to
    ``formatted_message`` and that function's result is returned.

    Args:
        audio_path: Path of the audio file to load.
        human_input: Text forwarded unchanged to ``formatted_message``.
        sample_rate: Sampling rate used when loading the audio and building
            the mel front end.
        window_size: ``win_length`` passed to ``AugmentMelSTFT``.
        hop_size: ``hopsize`` passed to ``AugmentMelSTFT``.
        n_mels: ``n_mels`` passed to ``AugmentMelSTFT``.
        cuda: When true and CUDA is available, inference runs under
            ``autocast``; otherwise no autocast context is used.

    Returns:
        Whatever ``formatted_message`` returns for the top label.

    Raises:
        ValueError: If ``audio_path`` is ``None`` or empty (for example when
            the request was submitted without an uploaded file).
    """
    # Fail early with a clear message instead of an obscure loader error.
    if not audio_path:
        raise ValueError("audio_path is required: no audio file was provided")

    waveform, _ = librosa.core.load(audio_path, sr=sample_rate, mono=True)

    mel = AugmentMelSTFT(
        n_mels=n_mels,
        sr=sample_rate,
        win_length=window_size,
        hopsize=hop_size,
    )
    mel.to(device)
    mel.eval()

    # Add a leading batch dimension before moving the samples to the device.
    waveform = torch.from_numpy(waveform[None, :]).to(device)

    # our models are trained in half precision mode (torch.float16)
    # run on cuda with torch.float16 to get the best performance
    # running on cpu with torch.float32 gives similar performance, using torch.bfloat16 is worse
    use_autocast = cuda and torch.cuda.is_available()
    precision_ctx = autocast(device_type=device.type) if use_autocast else nullcontext()
    with torch.no_grad(), precision_ctx:
        spec = mel(waveform)
        # The model also returns features; only the predictions are needed here.
        preds, _ = model(spec.unsqueeze(0))

    scores = torch.sigmoid(preds.float()).squeeze().cpu().numpy()

    # Only the single best class is used, so argmax replaces a full sort.
    label = labels[int(np.argmax(scores))]
    return formatted_message(format_classname(label), human_input)
def format_classname(classname):
    """Capitalize ``classname`` (first character upper-case, remainder lower-case).

    NOTE: this repeats the identical ``format_classname`` defined earlier in
    the file and rebinds the same name.
    """
    capitalized = classname.capitalize()
    return capitalized
def formatted_message(audio_class, human_input):
    """Answer ``human_input`` in the voice of the given non-human entity.

    A conversation chain is kept in the module-level ``chatgpt_chain`` together
    with the class name it was built for (``cached_audio_class``). The chain is
    rebuilt whenever the requested class differs from the cached one, or when
    no chain exists yet; otherwise the existing chain (and its memory) is
    reused.

    Args:
        audio_class: Name of the entity to embody; it is normalized with
            ``format_classname`` before use.
        human_input: The human's message, passed to the chain as
            ``human_input``.

    Returns:
        The value returned by ``chatgpt_chain.predict``.
    """
    global cached_audio_class
    global chatgpt_chain

    formatted_classname = format_classname(audio_class)

    # Rebuild when the entity changed, or when no chain has been created yet.
    if chatgpt_chain is None or cached_audio_class != formatted_classname:
        prefix = f"""You are going to act as a magical tool that allows for humans to communicate with non-human entities like rocks, crackling fire, trees, animals, and the wind. In order to do this, we're going to provide you the human's text input for the conversation. The goal is for you to embody that non-human entity and converse with the human.
Examples:
Non-human Entity: Tree
Human Input: Hello tree
Tree: Hello human, I am a tree
Let's begin:
Non-human Entity: {formatted_classname}"""
        # Doubled braces survive the f-string as single-brace placeholders
        # for PromptTemplate to fill in.
        suffix = f'''
{{history}}
Human Input: {{human_input}}
{formatted_classname}:'''
        prompt = PromptTemplate(
            input_variables=["history", "human_input"],
            template=prefix + suffix,
        )
        new_chain = LLMChain(
            llm=OpenAI(temperature=.5, openai_api_key=session_token),
            prompt=prompt,
            verbose=True,
            memory=ConversationalBufferWindowMemory(k=2, ai_prefix=formatted_classname),
        )
        # Commit the cache key only after the chain was built successfully, so
        # a failed build cannot leave the key pointing at a stale/None chain.
        chatgpt_chain = new_chain
        cached_audio_class = formatted_classname

    return chatgpt_chain.predict(human_input=human_input)
# Markup for the banner rendered at the top of the page.
header_html = """
<div style='text-align: center; width:100%; margin: auto;'>
<img src='logo.png' alt='anychat' width='250px' />
<h3>Non-Human entities have many things to say, listen to them!</h3>
</div>
"""

# NOTE(review): the source had no indentation, so the nesting of the
# components below is reconstructed — confirm the intended layout.
with gr.Blocks() as demo:
    gr.HTML(header_html)
    with gr.Row():
        with gr.Column():
            aud = gr.Audio(source="upload", type="filepath", label="Your audio")
            inp = gr.Textbox()
            out = gr.Textbox()
            btn = gr.Button("Run")
    # Clicking the button sends the audio path and the text to audio_tag and
    # writes its result into the output textbox.
    btn.click(fn=audio_tag, inputs=[aud, inp], outputs=out)

demo.launch()