Mahiruoshi committed on
Commit f4cadb2
1 Parent(s): 24fbdec

Update server.py

Files changed (1)
  1. server.py +60 -19
server.py CHANGED
@@ -4,8 +4,8 @@ from pathlib import Path
 
 import logging
 import re_matching
-
-from flask import Flask, request, jsonify
+import uuid
+from flask import Flask, request, jsonify, render_template_string
 from flask_cors import CORS
 
 logging.getLogger("numba").setLevel(logging.WARNING)
@@ -28,7 +28,7 @@ from tqdm import tqdm
 
 import utils
 from config import config
-
+import requests
 import torch
 import commons
 from text import cleaned_text_to_sequence, get_bert
@@ -44,9 +44,6 @@ import sys
 from scipy.io.wavfile import write
 
 net_g = None
-
-
-'''
 device = (
     "cuda:0"
     if torch.cuda.is_available()
@@ -56,8 +53,8 @@ device = (
         else "cpu"
     )
 )
-'''
-device = 'cpu'
+
+#device = 'cpu'
 
 def get_net_g(model_path: str, device: str, hps):
     net_g = SynthesizerTrn(
@@ -161,8 +158,9 @@ def infer(
     del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-    write("temp.wav", 44100, audio)
-    return 'success'
+    unique_filename = f"temp{uuid.uuid4()}.wav"
+    write(unique_filename, 44100, audio)
+    return unique_filename
 
 def is_japanese(string):
     for ch in string:
@@ -171,16 +169,29 @@ def is_japanese(string):
     return False
 
 def loadmodel(model):
-    _ = net_g.eval()
-    _ = utils.load_checkpoint(model, net_g, None, skip_optimizer=True)
-    return "success"
+    try:
+        _ = net_g.eval()
+        _ = utils.load_checkpoint(model, net_g, None, skip_optimizer=True)
+        return "success"
+    except:
+        return "error"
+
+def send_audio_to_server(audio_path, text):
+    url = "http://127.0.0.1:3000/response"
+    files = {'file': open(audio_path, 'rb')}
+    data = {'text': text}
+    try:
+        response = requests.post(url, files=files, data=data)
+        return response.status_code, response.text
+    except Exception as e:
+        return 500, str(e)
 
 app = Flask(__name__)
 CORS(app)
-@app.route('/tts')
+@app.route('/')
 
 def tts():
-    # these do not need to be changed
+    global last_text, last_model
     speaker = request.args.get('speaker')
     sdp_ratio = float(request.args.get('sdp_ratio', 0.2))
     noise_scale = float(request.args.get('noise_scale', 0.6))
@@ -188,13 +199,41 @@ def tts():
     length_scale = float(request.args.get('length_scale', 1))
     emotion = request.args.get('emotion', 'happy')
     text = request.args.get('text')
-    status = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, sid=speaker, reference_audio=None, emotion=emotion)
-    with open('temp.wav', 'rb') as bit:
-        wav_bytes = bit.read()
+    is_chat = request.args.get('is_chat', 'false').lower() == 'true'
+    model = request.args.get('model', modelPaths[-1])
+
+    if not speaker or not text:
+        return render_template_string("""
+        <!DOCTYPE html>
+        <html>
+        <head>
+            <title>TTS API Documentation</title>
+        </head>
+        <body>
+            <iframe src="http://love.soyorin.top" style="width:100%; height:100vh; border:none;"></iframe>
+        </body>
+        </html>
+        """)
 
+    if model != last_model:
+        unique_filename = loadmodel(model)
+        last_model = model
+    if is_chat and text == last_text:
+        # Generate 1 second of silence and return it
+        unique_filename = 'blank.wav'
+        silence = np.zeros(44100, dtype=np.int16)
+        write(unique_filename, 44100, silence)
+    else:
+        last_text = text
+        unique_filename = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, sid=speaker, reference_audio=None, emotion=emotion)
+        status_code, response_text = send_audio_to_server(unique_filename, text)
+        print(f"Response from server: {response_text} (Status code: {status_code})")
+    with open(unique_filename, 'rb') as bit:
+        wav_bytes = bit.read()
+    os.remove(unique_filename)
     headers = {
         'Content-Type': 'audio/wav',
-        'Text': status.encode('utf-8')}
+        'Text': unique_filename.encode('utf-8')}
    return wav_bytes, 200, headers
 
 
@@ -210,4 +249,6 @@ if __name__ == "__main__":
    )
    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
+    last_text = ""
+    last_model = modelPaths[-1]
    app.run(host="0.0.0.0", port=5000)
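
For reference, a minimal client sketch for the updated route (the endpoint moves from /tts to / and the response body is the WAV audio). This assumes the server is running locally on the port configured in app.run (5000); the speaker and text values below are placeholders, since valid speaker names come from hps.data.spk2id at runtime.

# client_example.py -- hypothetical client for the updated "/" endpoint
import requests

params = {
    "speaker": "some_speaker",      # placeholder: must match a key in hps.data.spk2id
    "text": "Hello from the TTS server.",
    "sdp_ratio": 0.2,               # same defaults as the server-side request.args.get calls
    "noise_scale": 0.6,
    "length_scale": 1,
    "emotion": "happy",
    "is_chat": "false",             # "true" plus a repeated text returns 1 s of silence instead
}

resp = requests.get("http://127.0.0.1:5000/", params=params)
resp.raise_for_status()

# The server answers with Content-Type: audio/wav and the raw WAV bytes.
with open("output.wav", "wb") as f:
    f.write(resp.content)

Omitting speaker or text returns the HTML documentation page instead of audio, per the new render_template_string branch.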