import os
from contextlib import nullcontext

import gradio as gr
import torch
import numpy as np
import librosa
from torch import autocast

from efficientat.models.MobileNetV3 import get_model as get_mobilenet, get_ensemble_model
from efficientat.models.preprocess import AugmentMelSTFT
from efficientat.helpers.utils import NAME_TO_WIDTH, labels

from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate
from langchain.chains.conversation.memory import ConversationalBufferWindowMemory
MODEL_NAME = "mn40_as"

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = get_mobilenet(width_mult=NAME_TO_WIDTH(MODEL_NAME), pretrained_name=MODEL_NAME)
model.to(device)
model.eval()

# Cache the last detected audio class and the LangChain objects built for it,
# so the prompt/chain are only rebuilt when the detected sound source changes.
cached_audio_class = None
template = None
prompt = None
chain = None
def audio_tag(
        audio_path,
        sample_rate=32000,
        window_size=800,
        hop_size=320,
        n_mels=128,
        cuda=True,
):
    (waveform, _) = librosa.core.load(audio_path, sr=sample_rate, mono=True)
    audio_length = round(len(waveform) / sample_rate, 2)
    mel = AugmentMelSTFT(n_mels=n_mels, sr=sample_rate, win_length=window_size, hopsize=hop_size)
    mel.to(device)
    mel.eval()
    waveform = torch.from_numpy(waveform[None, :]).to(device)

    # our models are trained in half precision mode (torch.float16)
    # run on cuda with torch.float16 to get the best performance
    # running on cpu with torch.float32 gives similar performance, using torch.bfloat16 is worse
    with torch.no_grad(), autocast(device_type=device.type) if cuda and torch.cuda.is_available() else nullcontext():
        spec = mel(waveform)
        preds, features = model(spec.unsqueeze(0))
    preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy()

    # Take the top-scoring audio tag as the "source" of the sound
    sorted_indexes = np.argsort(preds)[::-1]
    label = labels[sorted_indexes[0]]
    return formatted_message(label, audio_length)
def formatted_message(audio_class, audio_length, user_text="Hello, who are you?", session_token=None):
    # NOTE: the Gradio interface below only supplies the audio file, so `user_text`
    # falls back to a default opening message, and the OpenAI key is read from the
    # environment unless a `session_token` is passed in explicitly.
    global cached_audio_class, template, prompt, chain

    # Rebuild the prompt and chain only when the detected sound source changes.
    if cached_audio_class != audio_class:
        cached_audio_class = audio_class

        prefix = '''You are going to act as a magical tool that allows for humans to communicate with non-human entities like
rocks, crackling fire, trees, animals, and the wind. In order to do this, we're going to provide you a data string which
represents the audio input, the source of the audio, and the human's text input for the conversation.
The goal is for you to embody the source of the audio, and use the length and variance in the signal data to produce
plausible responses to the human's input. Remember to embody the source data. When we start the conversation,
you should generate a "personality profile" for the source and utilize that personality profile in your responses.
Let's begin:'''
        # `{history}` and `{human_input}` are left as template variables for LangChain to fill in.
        suffix = f'''Source: {audio_class}
Length of Audio in Seconds: {audio_length}
{{history}}
Human Input: {{human_input}}
{audio_class} Response:'''
        template = prefix + suffix

        prompt = PromptTemplate(
            input_variables=["history", "human_input"],
            template=template
        )
        chain = LLMChain(
            llm=OpenAI(temperature=.5, openai_api_key=session_token or os.environ.get("OPENAI_API_KEY")),
            prompt=prompt,
            verbose=True,
            memory=ConversationalBufferWindowMemory(k=2),
        )

    output = chain.predict(human_input=user_text)
    return output
demo = gr.Interface(
    audio_tag,
    gr.Audio(source="upload", type="filepath", label="Your audio"),
    gr.Textbox(),
    examples=[["metro_station-paris.wav"]]
).launch(debug=True)
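# Rough sketch of the dependencies this Space needs, inferred from the imports above
# (package names only; exact pins were not part of the original snippet and are assumptions):
#
#   gradio
#   torch
#   numpy
#   librosa
#   langchain
#   openai          # used by langchain's OpenAI wrapper
#
# The `efficientat` package is the EfficientAT audio-tagging code, typically checked
# into the Space alongside app.py rather than installed from PyPI.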