imseldrith committed on
Commit
a07ed46
1 Parent(s): eece182

Upload 3 files

Browse files
Files changed (3) hide show
  1. src/functions.py +53 -0
  2. src/main.py +101 -0
  3. src/schemas.py +9 -0
src/functions.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import json
3
+
4
def get_time_utc(zone, delay=0):
    """Return the current time, shifted by an hour offset, as a string.

    Args:
        zone: Offset from UTC in hours (may be negative).
        delay: Additional offset in seconds added to the current time.

    Returns:
        The shifted time formatted as ``YYYY-MM-DD HH:MM:SS``.
    """
    seconds_per_hour = 60 * 60
    shifted_epoch = time.time() + delay + zone * seconds_per_hour
    # gmtime() applies no local-timezone conversion, so the only shift in the
    # result is the one computed above.
    shifted = time.gmtime(shifted_epoch)
    return time.strftime("%Y-%m-%d %H:%M:%S", shifted)
7
+
8
def clear_dict(d):
    """Recursively strip ``None`` values from nested dicts and lists.

    Args:
        d: Any value. Dicts and lists are cleaned recursively; every other
            value is returned unchanged.

    Returns:
        * ``None`` if ``d`` is ``None`` or is a dict that ends up empty.
        * A new list without ``None`` entries (an empty list stays a list).
        * A new dict without ``None``-valued keys.
        * ``d`` itself for any other type.
    """
    if d is None:
        return None

    if isinstance(d, list):
        cleaned_items = (clear_dict(item) for item in d)
        return [item for item in cleaned_items if item is not None]

    if not isinstance(d, dict):
        return d

    pruned = {}
    for key, value in d.items():
        cleaned = clear_dict(value)
        if cleaned is not None:
            pruned[key] = cleaned

    # A dict with nothing left collapses to None so that parents drop it too.
    return pruned or None
23
+
24
def print_env(server_port=6006, sleep=3):
    """Print a start-up banner with the effective settings, then pause.

    The pause gives an operator time to read the settings before the server
    starts.

    Args:
        server_port: Port number shown in the banner.
        sleep: Number of seconds to pause after printing. The banner reports
            this value, so the message stays accurate for non-default pauses.
    """
    rule = "###########################################"
    notice = (
        "Please check the environment variables (the program will start in "
        + str(sleep)
        + " seconds) ↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑"
    )
    banner_lines = [
        "",
        "",
        rule,
        "environment variable start-----------------------------------",
        rule,
        "",
        "server_port: " + str(server_port),
        "",
        rule,
        notice,
        rule,
        "",
        "",
    ]
    for line in banner_lines:
        print(line)

    time.sleep(sleep)
42
+
43
def print_log(request, respose, time_start=0):
    """Print a request/response pair together with timing information.

    Args:
        request: Object whose ``__dict__`` is dumped as JSON after ``None``
            values are stripped by ``clear_dict``.
        respose: Response to print. Dicts and lists are printed as-is; for
            any other object its ``__dict__`` is printed.
        time_start: ``time.time()`` value taken when handling began; used to
            compute the elapsed time.
    """
    print("______________________________________________")
    request_json = json.dumps(clear_dict(request.__dict__))
    print("request:::\n" + request_json)

    print("respose:::")
    payload = respose if isinstance(respose, (dict, list)) else respose.__dict__
    print(payload)

    elapsed = time.time() - time_start
    print("cost:::\n" + str(elapsed) + "s")
    # The finish timestamp is rendered with a fixed -8 hour offset.
    print("finish:::\n" + get_time_utc(-8))
src/main.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ import schemas
3
+ import uvicorn
4
+ from starlette.middleware.cors import CORSMiddleware
5
+ from functions import *
6
+ import base64
7
+ import os
8
+ import traceback
9
+
10
+ from bark import SAMPLE_RATE, generate_audio, preload_models
11
+ import soundfile as sf
12
+ import wave
13
+ import numpy as np
14
+ import nltk
15
+
16
# Port the HTTP server listens on (passed to uvicorn in the __main__ guard).
server_port = 6006

# Runs at import time, before the app is created. Judging by its name this
# loads the Bark models up front rather than on the first request.
preload_models()

# Both interactive documentation routes are switched off.
app = FastAPI(redoc_url=None, docs_url=None)

# Origins accepted by the CORS middleware; "*" accepts every origin.
origins = ["*"]
26
+
27
+
28
def concatenate_wavs(wav_files, output_file, silence_duration=0.3):
    """Join several PCM WAV files into one, adding silence after each one.

    Args:
        wav_files: Iterable of input WAV paths (or binary file objects), in
            playback order. All inputs must share the same channel count,
            sample width and frame rate.
        output_file: Path (or binary file object) of the WAV file to write.
        silence_duration: Seconds of silence written after every input
            segment, including the last one.

    Raises:
        ValueError: If ``wav_files`` is empty or the inputs do not all share
            the same audio format.
    """
    wav_files = list(wav_files)
    if not wav_files:
        raise ValueError("concatenate_wavs() needs at least one input file")

    audio_format = None
    segments = []
    for source in wav_files:
        # The context manager closes every reader, even if a later input
        # turns out to be unreadable or mismatched.
        with wave.open(source, 'rb') as reader:
            current = (reader.getnchannels(),
                       reader.getsampwidth(),
                       reader.getframerate())
            if audio_format is None:
                audio_format = current
            elif current != audio_format:
                raise ValueError(
                    "all input WAV files must share channels, sample width "
                    "and frame rate; got %r and %r" % (audio_format, current))
            segments.append(reader.readframes(reader.getnframes()))

    nchannels, sampwidth, framerate = audio_format
    frame_size = nchannels * sampwidth  # bytes per frame

    # Size the silence from the real frame layout, so it is correct for mono
    # and for sample widths other than 16-bit. 8-bit WAV samples are
    # unsigned, so their silent level is 0x80 rather than 0x00.
    silence_frames = max(int(silence_duration * framerate), 0)
    silent_byte = b"\x80" if sampwidth == 1 else b"\x00"
    silence = silent_byte * (silence_frames * frame_size)

    # The header field is a frame count, not a byte count.
    total_frames = sum(len(seg) // frame_size + silence_frames
                       for seg in segments)

    with wave.open(output_file, 'wb') as output:
        output.setparams((nchannels, sampwidth, framerate, total_frames,
                          'NONE', 'Uncompressed'))
        for seg in segments:
            output.writeframes(seg)
            output.writeframes(silence)
47
+
48
+
49
# Enable CORS for browser clients served from other origins.
# NOTE(review): credentials are allowed together with wildcard origins,
# methods and headers. Confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,    # origins permitted to call the API
    allow_credentials=True,   # let cross-origin requests carry credentials
    allow_methods=["*"],      # every HTTP method (GET, POST, PUT, ...)
    allow_headers=["*"],      # every request header
)
56
+
57
+
58
@app.post("/tts_bark/")
async def tts_bark(item: schemas.generate_web):
    """Synthesize speech for ``item.text`` and return it as base64 OGG/Opus.

    The text is split into sentences, each sentence is rendered with Bark,
    the clips are joined into one WAV and ffmpeg converts that to OGG.

    Args:
        item: Request body; ``item.text`` is the text to speak.

    Returns:
        On success, a dict with ``file_base64`` (the OGG bytes, base64
        encoded), ``audio_text`` (the input text) and ``file_name``.
        On any failure, a dict with ``code`` 9, ``msg``, ``err`` and
        ``traceback``.
    """
    # Function-scope imports keep the handler self-contained without touching
    # the module's import block.
    import subprocess
    import tempfile

    time_start = time.time()
    text = item.text
    print(f"{text=}")
    try:
        sentences = nltk.sent_tokenize(text)
        file_name_pre = f"out-{time.time()}"
        file_name_ogg = file_name_pre + ".ogg"

        # All intermediate files live in a private directory that is removed
        # on both the success and the error path. This avoids leaking the
        # per-sentence WAVs and name clashes in the working directory.
        with tempfile.TemporaryDirectory() as work_dir:
            wavs = []
            for idx, s in enumerate(sentences, start=1):
                audio_array = generate_audio(s, history_prompt="en_speaker_8", text_temp=0.6, waveform_temp=0.6)
                fname = os.path.join(work_dir, f"tmp-{idx}.wav")
                sf.write(fname, audio_array, SAMPLE_RATE)
                wavs.append(fname)

            wav_path = os.path.join(work_dir, file_name_pre + ".wav")
            ogg_path = os.path.join(work_dir, file_name_ogg)
            concatenate_wavs(wavs, wav_path)

            # Convert to OGG. An argument list avoids shell parsing, and
            # check=True turns an ffmpeg failure into an exception instead of
            # a confusing missing-file error on the read below.
            subprocess.run(
                ["ffmpeg", "-i", wav_path, "-c:a", "libopus", "-b:a", "64k", "-y", ogg_path],
                check=True,
            )

            with open(ogg_path, "rb") as f:
                audio_content = f.read()

        base64_audio = base64.b64encode(audio_content).decode("utf-8")
        res = {"file_base64": base64_audio,
               "audio_text": text,
               "file_name": file_name_ogg,
               }
        print_log(item, res, time_start)
        return res
    except Exception as err:
        # Boundary handler: every failure is reported to the caller as a
        # payload rather than an HTTP error, matching the existing contract.
        res = {"code": 9, "msg": "api error", "err": str(err), "traceback": traceback.format_exc()}
        print_log(item, res, time_start)
        return res
97
+
98
if __name__ == "__main__":
    # Show the effective settings (print_env pauses before returning), then
    # serve the app on all interfaces without auto-reload.
    print_env(server_port)
    uvicorn.run("main:app", host="0.0.0.0", port=server_port, reload=False)
src/schemas.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import Union
3
+
4
+
5
# Request body model: carries the text submitted by the caller.
class generate_web(BaseModel):
    # Optional text payload; None when the field is omitted.
    text: Union[str, None] = None

    class Config:
        # NOTE(review): presumably the pydantic v1 option for building the
        # model from attribute-bearing objects. Confirm it is needed here.
        orm_mode = True