File size: 4,083 Bytes
320951c
 
0bca911
4848335
 
 
 
 
 
 
 
 
 
 
320951c
 
 
4848335
 
be42941
 
4848335
 
 
 
 
64e1a6c
f034cbf
7965d4d
 
 
 
4848335
32f98c6
 
 
4848335
 
acb7417
4848335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320951c
 
acb7417
32f98c6
5dc9c2b
 
32f98c6
acb7417
be42941
 
64e1a6c
5dc9c2b
 
 
 
55fb629
5dc9c2b
 
 
 
 
 
 
 
 
 
 
55fb629
5dc9c2b
55fb629
 
 
 
5dc9c2b
55fb629
 
 
 
35032fa
55fb629
5dc9c2b
320951c
 
0bca911
b896898
 
 
08ba3e2
b896898
 
 
 
 
e576f9f
b896898
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os

import gradio as gr
import torch
import numpy as np
import librosa

from efficientat.models.MobileNetV3 import get_model as get_mobilenet, get_ensemble_model
from efficientat.models.preprocess import AugmentMelSTFT
from efficientat.helpers.utils import NAME_TO_WIDTH, labels

from torch import autocast
from contextlib import nullcontext

from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate
from langchain.chains.conversation.memory import ConversationalBufferWindowMemory

# Name of the pretrained checkpoint to load; it is also the key used to look up
# the width multiplier passed to get_mobilenet below.
MODEL_NAME = "mn40_as"

# Key handed to the OpenAI LLM wrapper in formatted_message. Indexing os.environ
# directly means a missing SESSION_TOKEN raises KeyError while the module loads.
session_token = os.environ["SESSION_TOKEN"]

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# The tagging model is built once at import time, moved to the chosen device and
# put into eval mode; audio_tag reuses this single instance for every request.
model = get_mobilenet(width_mult=NAME_TO_WIDTH(MODEL_NAME), pretrained_name=MODEL_NAME)
model.to(device)
model.eval()    

# Conversation state shared by all calls to formatted_message.
# chatgpt_chain holds the current LLMChain; it stays None until the first call.
chatgpt_chain = None
# Class name the current chain was built for. NOTE(review): "c" looks like a
# placeholder that a capitalized class name is not expected to equal, so the
# first call always builds a chain — confirm.
cached_audio_class = "c" 
# NOTE(review): in the visible code, formatted_message assigns `template` and
# `prompt` as locals (no `global` statement) and `chain` is never referenced,
# so these three module-level names appear to remain None.
template = None
prompt = None
chain = None


def format_classname(classname):
    """Return ``classname`` with its first character upper-cased and the rest lower-cased."""
    capitalized = classname.capitalize()
    return capitalized

def audio_tag(
    audio_path,
    human_input,
    sample_rate=32000,
    window_size=800,
    hop_size=320,
    n_mels=128,
    cuda=True,
):
    """Classify an audio clip and reply to ``human_input`` as its top-scoring class.

    The clip is loaded as mono audio at ``sample_rate``, turned into a mel
    spectrogram and scored by the module-level ``model``. The label with the
    highest sigmoid score is passed, together with ``human_input``, to
    ``formatted_message``.

    Args:
        audio_path: Path of the audio file to load with librosa.
        human_input: Text forwarded unchanged to ``formatted_message``.
        sample_rate: Sampling rate used for loading and for the mel front end.
        window_size: ``win_length`` passed to ``AugmentMelSTFT``.
        hop_size: ``hopsize`` passed to ``AugmentMelSTFT``.
        n_mels: ``n_mels`` passed to ``AugmentMelSTFT``.
        cuda: When true and CUDA is available, inference runs under autocast.

    Returns:
        The value returned by ``formatted_message`` for the predicted class.

    Raises:
        ValueError: If ``audio_path`` is ``None`` or empty (no clip supplied).
    """
    # Fail early with a readable message instead of an opaque loader error
    # when the handler is triggered without an uploaded clip.
    if not audio_path:
        raise ValueError("No audio file was provided; please upload an audio clip first.")

    waveform, _ = librosa.core.load(audio_path, sr=sample_rate, mono=True)
    mel = AugmentMelSTFT(n_mels=n_mels, sr=sample_rate, win_length=window_size, hopsize=hop_size)
    mel.to(device)
    mel.eval()
    # Add a leading batch axis before moving the samples to the model's device.
    waveform = torch.from_numpy(waveform[None, :]).to(device)

    # our models are trained in half precision mode (torch.float16)
    # run on cuda with torch.float16 to get the best performance
    # running on cpu with torch.float32 gives similar performance, using torch.bfloat16 is worse
    use_autocast = cuda and torch.cuda.is_available()
    precision_context = autocast(device_type=device.type) if use_autocast else nullcontext()
    with torch.no_grad(), precision_context:
        spec = mel(waveform)
        preds, _ = model(spec.unsqueeze(0))
    preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy()

    # Only the single best class is needed, so argmax replaces a full sort.
    top_index = int(np.argmax(preds))
    label = labels[top_index]
    return formatted_message(format_classname(label), human_input)

def format_classname(classname):
    """Capitalize ``classname`` (first character upper-case, remainder lower-case).

    NOTE: this repeats an identical definition that appears earlier in the
    module; being the later of the two, it is the one bound to the name once
    the module has finished loading.
    """
    return classname.capitalize()
    
def formatted_message(audio_class, human_input):
    """Answer ``human_input`` in the voice of the entity named by ``audio_class``.

    A module-level ``LLMChain`` is reused between calls. It is rebuilt, with a
    new prompt and a new conversation memory, whenever no chain exists yet or
    the formatted class name differs from the one the current chain was built
    for.

    Args:
        audio_class: Name of the non-human entity to impersonate.
        human_input: The human's message for the conversation.

    Returns:
        The result of ``chatgpt_chain.predict`` for ``human_input``.
    """
    global cached_audio_class
    global chatgpt_chain

    formatted_classname = format_classname(audio_class)

    # Checking for a missing chain explicitly avoids depending on the initial
    # placeholder value of cached_audio_class to force the first build.
    if chatgpt_chain is None or cached_audio_class != formatted_classname:
        # The template is parsed with f-string style placeholders, so literal
        # braces in the class name must be doubled to stay literal.
        template_classname = formatted_classname.replace("{", "{{").replace("}", "}}")

        prefix = f"""You are going to act as a magical tool that allows for humans to communicate with non-human entities like rocks, crackling fire, trees, animals, and the wind. In order to do this, we're going to provide you the human's text input for the conversation. The goal is for you to embody that non-human entity and converse with the human.
Examples:
Non-human Entity: Tree
Human Input: Hello tree
Tree: Hello human, I am a tree
Let's begin:
Non-human Entity: {template_classname}"""
        suffix = f'''
{{history}}
Human Input: {{human_input}}
{template_classname}:'''
        template = prefix + suffix

        prompt = PromptTemplate(
            input_variables=["history", "human_input"],
            template=template,
        )

        new_chain = LLMChain(
            llm=OpenAI(temperature=.5, openai_api_key=session_token),
            prompt=prompt,
            verbose=True,
            memory=ConversationalBufferWindowMemory(k=2, ai_prefix=formatted_classname),
        )

        # Publish the chain and its cache key together, and only after the
        # build succeeded, so a failed build cannot leave a stale chain
        # associated with the new class name.
        chatgpt_chain = new_chain
        cached_audio_class = formatted_classname

    output = chatgpt_chain.predict(human_input=human_input)

    return output

# Gradio UI: an audio upload and a text box feed audio_tag, and the value it
# returns is shown in the output text box.
with gr.Blocks() as demo: 
    # Static header with the logo and tagline.
    gr.HTML("""
    <div style='text-align: center; width:100%; margin: auto;'>
        <img src='logo.png' alt='anychat' width='250px' />
        <h3>Non-Human entities have many things to say, listen to them!</h3>
    </div>
    """)
    with gr.Row():
        with gr.Column():
            # type="filepath": presumably the handler receives a path on disk,
            # which matches audio_tag passing it to librosa's loader — confirm
            # against the Gradio version in use.
            aud = gr.Audio(source="upload", type="filepath", label="Your audio")
            # Human message for the conversation.
            inp = gr.Textbox()
        # Reply text produced by the handler.
        out = gr.Textbox()
    btn = gr.Button("Run")
    # The two inputs map positionally onto audio_tag(audio_path, human_input);
    # its remaining parameters keep their defaults.
    btn.click(fn=audio_tag, inputs=[aud, inp], outputs=out)

    # NOTE(review): launch() is called while still inside the Blocks context and
    # unconditionally at import time (no __main__ guard) — confirm this is intended.
    demo.launch()