Spaces:
Runtime error
Runtime error
File size: 4,083 Bytes
320951c 0bca911 4848335 320951c 4848335 be42941 4848335 64e1a6c f034cbf 7965d4d 4848335 32f98c6 4848335 acb7417 4848335 320951c acb7417 32f98c6 5dc9c2b 32f98c6 acb7417 be42941 64e1a6c 5dc9c2b 55fb629 5dc9c2b 55fb629 5dc9c2b 55fb629 5dc9c2b 55fb629 35032fa 55fb629 5dc9c2b 320951c 0bca911 b896898 08ba3e2 b896898 e576f9f b896898 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import os
import gradio as gr
import torch
import numpy as np
import librosa
from efficientat.models.MobileNetV3 import get_model as get_mobilenet, get_ensemble_model
from efficientat.models.preprocess import AugmentMelSTFT
from efficientat.helpers.utils import NAME_TO_WIDTH, labels
from torch import autocast
from contextlib import nullcontext
from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate
from langchain.chains.conversation.memory import ConversationalBufferWindowMemory
MODEL_NAME = "mn40_as"

# SESSION_TOKEN is handed to langchain's OpenAI wrapper as the API key (see
# formatted_message). Fail fast with an actionable message instead of a bare
# KeyError when the secret has not been configured.
try:
    session_token = os.environ["SESSION_TOKEN"]
except KeyError as err:
    raise RuntimeError(
        "The SESSION_TOKEN environment variable is not set; "
        "it must hold the OpenAI API key used by the chat chain."
    ) from err

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EfficientAT MobileNetV3 audio tagger: built once at import time and shared by
# every call to audio_tag. eval() switches it to inference mode.
model = get_mobilenet(width_mult=NAME_TO_WIDTH(MODEL_NAME), pretrained_name=MODEL_NAME)
model.to(device)
model.eval()

# Conversation state. These are plain module globals, so every request to the
# app shares them (they are not per-session).
chatgpt_chain = None
# Sentinel meaning "no class seen yet": format_classname() capitalizes its
# input, so a formatted label can never equal the lowercase "c".
cached_audio_class = "c"

# Not read by the code visible here (formatted_message uses locals of the same
# names); kept so these module-level names remain available.
template = None
prompt = None
chain = None
def format_classname(classname):
    """Return a display form of a tagger label.

    Applies ``str.capitalize``: the first character is capitalized and the
    rest of the string is lower-cased.
    """
    display_name = classname.capitalize()
    return display_name
def _classify_audio(audio_path, sample_rate, window_size, hop_size, n_mels, cuda):
    """Return the entry of ``labels`` whose model score is highest for the given audio file."""
    (waveform, _) = librosa.core.load(audio_path, sr=sample_rate, mono=True)
    mel = AugmentMelSTFT(n_mels=n_mels, sr=sample_rate, win_length=window_size, hopsize=hop_size)
    mel.to(device)
    mel.eval()
    # Add a leading axis (presumably the batch axis) before the mel frontend.
    waveform = torch.from_numpy(waveform[None, :]).to(device)

    # our models are trained in half precision mode (torch.float16)
    # run on cuda with torch.float16 to get the best performance
    # running on cpu with torch.float32 gives similar performance, using torch.bfloat16 is worse
    use_autocast = cuda and torch.cuda.is_available()
    precision_ctx = autocast(device_type=device.type) if use_autocast else nullcontext()
    with torch.no_grad(), precision_ctx:
        spec = mel(waveform)
        preds, _features = model(spec.unsqueeze(0))

    # Cast back to float32 before the sigmoid so scores are comparable regardless of autocast.
    preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy()
    return labels[int(np.argmax(preds))]


def audio_tag(
    audio_path,
    human_input,
    sample_rate=32000,
    window_size=800,
    hop_size=320,
    n_mels=128,
    cuda=True,
):
    """Tag an audio clip, then reply to ``human_input`` in the voice of the detected sound.

    Args:
        audio_path: Path to the uploaded audio file (the Gradio component uses
            ``type="filepath"``). ``None`` or an empty string means nothing was uploaded.
        human_input: The user's chat message.
        sample_rate: Rate the audio is resampled to on load.
        window_size: STFT window length passed to the mel frontend.
        hop_size: STFT hop size passed to the mel frontend.
        n_mels: Number of mel bins.
        cuda: When True and CUDA is available, run inference under ``torch.autocast``.

    Returns:
        The chat reply text, or a prompt to upload audio when no file was given.
    """
    if not audio_path:
        # With no upload, the audio component presumably yields None; librosa
        # cannot load that, so ask the user for a file instead of crashing.
        return "Please upload an audio file first."

    label = _classify_audio(audio_path, sample_rate, window_size, hop_size, n_mels, cuda)
    return formatted_message(format_classname(label), human_input)
def format_classname(classname):
    """Capitalize a tagger label for display (``str.capitalize`` semantics).

    NOTE(review): this redefines an identical ``format_classname`` declared
    earlier in the file. The later definition is the one bound at import time,
    so one of the two copies can be removed.
    """
    capitalized = classname.capitalize()
    return capitalized
def _build_chain(formatted_classname):
    """Create an LLMChain whose prompt makes the model speak as ``formatted_classname``."""
    prefix = (
        "You are going to act as a magical tool that allows for humans to communicate with non-human entities like rocks, crackling fire, trees, animals, and the wind. In order to do this, we're going to provide you the human's text input for the conversation. The goal is for you to embody that non-human entity and converse with the human.\n"
        "Examples:\n"
        "Non-human Entity: Tree\n"
        "Human Input: Hello tree\n"
        "Tree: Hello human, I am a tree\n"
        "Let's begin:\n"
        f"Non-human Entity: {formatted_classname}"
    )
    # {history} and {human_input} are PromptTemplate placeholders filled in by
    # the chain (plain strings here, so the braces stay literal).
    suffix = (
        "\n"
        "{history}\n"
        "Human Input: {human_input}\n"
        f"{formatted_classname}:"
    )
    prompt = PromptTemplate(
        input_variables=["history", "human_input"],
        template=prefix + suffix,
    )
    return LLMChain(
        llm=OpenAI(temperature=.5, openai_api_key=session_token),
        prompt=prompt,
        verbose=True,
        # Keep only a short window (k=2) of past turns; replies are labelled with the entity name.
        memory=ConversationalBufferWindowMemory(k=2, ai_prefix=formatted_classname),
    )


def formatted_message(audio_class, human_input):
    """Reply to ``human_input`` as the non-human entity named by ``audio_class``.

    The chain (and its conversation memory) is reused while the formatted
    class stays the same, and rebuilt whenever it changes.

    NOTE(review): the chain lives in module globals, so every user of the app
    shares one conversation history.

    Args:
        audio_class: Label of the detected sound; it is capitalized before use.
        human_input: The user's message.

    Returns:
        The text predicted by the LLM chain.
    """
    global cached_audio_class
    global chatgpt_chain

    formatted_classname = format_classname(audio_class)
    if chatgpt_chain is None or cached_audio_class != formatted_classname:
        # Build first, then commit both globals. If construction raises, the
        # cache must not claim this class is ready while the chain is missing or stale.
        new_chain = _build_chain(formatted_classname)
        chatgpt_chain = new_chain
        cached_audio_class = formatted_classname

    return chatgpt_chain.predict(human_input=human_input)
# Gradio UI: upload a sound, type a message, and get a reply written in the
# voice of whatever the audio tagger detected.
with gr.Blocks() as demo:
    # Header banner.
    # NOTE(review): 'logo.png' is a bare relative URL; confirm it actually
    # resolves when the page is served by Gradio.
    gr.HTML("""
<div style='text-align: center; width:100%; margin: auto;'>
<img src='logo.png' alt='anychat' width='250px' />
<h3>Non-Human entities have many things to say, listen to them!</h3>
</div>
""")
    with gr.Row():
        with gr.Column():
            # type="filepath" makes the component hand audio_tag a path on disk.
            # NOTE(review): `source=` is an older Gradio keyword (newer releases
            # use `sources=[...]`); confirm against the pinned Gradio version.
            aud = gr.Audio(source="upload", type="filepath", label="Your audio")
            inp = gr.Textbox()  # user's message -> audio_tag's human_input
            out = gr.Textbox()  # chat reply returned by audio_tag
            btn = gr.Button("Run")
    # Inputs map positionally onto audio_tag(audio_path, human_input).
    btn.click(fn=audio_tag, inputs=[aud, inp], outputs=out)
# Runs at module level, so the server starts whenever this file is executed or imported.
demo.launch()
|