File size: 1,673 Bytes
a3d51f8
f372cec
a3d51f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d268f5
a3d51f8
 
 
 
f372cec
 
 
5d268f5
 
 
f372cec
 
a3d51f8
 
571db8d
f94553c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import time
import base64
import gradio as gr
from sentence_transformers import SentenceTransformer

import httpx
import json

from utils import get_tags_for_prompts, get_mubert_tags_embeddings, get_pat

# Sentence-embedding model used (via get_tags_for_prompts) to map free-text
# prompts onto Mubert tags. Loaded once at import time and reused by every
# request handled by generate_track_by_prompt.
minilm = SentenceTransformer('all-MiniLM-L6-v2')
# Tag embeddings computed by utils.get_mubert_tags_embeddings with `minilm`;
# presumably one vector per Mubert tag — confirm in utils. Passed to
# get_tags_for_prompts on every prompt.
mubert_tags_embeddings = get_mubert_tags_embeddings(minilm)


def get_track_by_tags(tags, pat, duration, maxit=20, loop=False):
    """Ask the Mubert B2B API to record a track and wait until it can be fetched.

    Sends a ``RecordTrackTTM`` request, then polls the returned download link
    (one GET per second) until it answers with HTTP 200.

    Args:
        tags: Mubert tags describing the music, forwarded as the API's ``tags``.
        pat: Personal access token, forwarded as the API's ``pat``.
        duration: Requested length, forwarded as the API's ``duration``.
        maxit: Maximum number of polling attempts before giving up.
        loop: If true, request ``mode="loop"``; otherwise ``mode="track"``.

    Returns:
        The download URL of the finished track.

    Raises:
        RuntimeError: If the API response does not report ``status == 1``.
        TimeoutError: If the download link is still not available after
            ``maxit`` polling attempts (previously this silently returned None).
        ValueError: If the API reply body is not valid JSON.
        httpx.HTTPError: On network/transport failures.
    """
    mode = "loop" if loop else "track"
    r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM',
                   json={
                       "method": "RecordTrackTTM",
                       "params": {
                           "pat": pat,
                           "duration": duration,
                           "tags": tags,
                           "mode": mode,
                       },
                   })

    rdata = r.json()
    # Explicit check instead of `assert`: asserts are stripped under `python -O`.
    if rdata.get('status') != 1:
        error = rdata.get('error') or {}
        raise RuntimeError(error.get('text') or f"Mubert API request failed: {rdata!r}")
    trackurl = rdata['data']['tasks'][0]['download_link']

    # Rendering is asynchronous: keep polling until the link serves the file
    # (HTTP 200). Presumably it returns an error status until then — confirm.
    print('Generating track ', end='', flush=True)
    for attempt in range(maxit):
        if httpx.get(trackurl).status_code == 200:
            print()
            return trackurl
        print('.', end='', flush=True)
        # No point sleeping after the final attempt.
        if attempt < maxit - 1:
            time.sleep(1)
    print()
    raise TimeoutError(f"Track was not ready after {maxit} attempts: {trackurl}")


def generate_track_by_prompt(prompt, *, duration=30, loop=False):
    """Turn a free-text prompt into a Mubert track URL for the Gradio UI.

    The prompt is mapped to Mubert tags with ``get_tags_for_prompts`` (using
    the module-level ``minilm`` encoder and ``mubert_tags_embeddings``), then
    rendered via ``get_track_by_tags``.

    Args:
        prompt: Free-text description of the desired music.
        duration: Requested track length forwarded to ``get_track_by_tags``
            (default 30, the previously hard-coded value).
        loop: Request a loop instead of a track (default False).

    Returns:
        The track download URL on success, or an error message string on
        failure. Errors are returned rather than raised so they appear in the
        interface's "Result" textbox.
    """
    try:
        # NOTE(review): this literal looks like a Cloudflare "email protection"
        # placeholder introduced when the page was scraped, not a real account
        # address — confirm the intended email.
        pat = get_pat("[email protected]")
        _, tags = get_tags_for_prompts(minilm, mubert_tags_embeddings, [prompt])[0]
        result = get_track_by_tags(tags, pat, int(duration), loop=loop)
        print(result)
        # Older get_track_by_tags versions return None when polling runs out;
        # surface that instead of showing a blank result.
        if result is None:
            return "Track generation timed out"
        return result
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing.
        # Fall back to repr() so message-less exceptions are not blank.
        return str(e) or repr(e)


# One text input (the prompt) -> one text output holding either the track
# download URL or an error message from generate_track_by_prompt.
iface = gr.Interface(
    fn=generate_track_by_prompt,
    inputs=["text"],
    outputs=[gr.Text(label="Result")],
)
# Route requests through Gradio's queue, since each request blocks while the
# track is being polled. Per Gradio 3.x `queue()`: at most 100 pending events,
# 20 processed concurrently.
iface.queue(max_size=100, concurrency_count=20)
iface.launch()