Jofthomas HF staff commited on
Commit
c76a369
1 Parent(s): 695a847

Update TextGen/router.py

Browse files
Files changed (1) hide show
  1. TextGen/router.py +23 -21
TextGen/router.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import time
3
  from langchain_core.pydantic_v1 import BaseModel, Field
4
  from fastapi import FastAPI, HTTPException, Query, Request
5
- from fastapi.responses import FileResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
 
8
  from langchain.chains import LLMChain
@@ -128,27 +128,29 @@ def determine_vocie_from_npc(npc,genre):
128
  return "./voices/narator_out.wav"
129
 
130
  @app.post("/generate_wav")
131
- async def generate_wav(message:VoiceMessage):
132
  try:
133
- voice=determine_vocie_from_npc(message.npc, message.genre)
134
- # Use the Gradio client to generate the wav file
135
- result = tts_client.predict(
136
- prompt=message.input,
137
- language=message.language,
138
- audio_file_pth=handle_file(voice),
139
- mic_file_path=None,
140
- use_mic=False,
141
- voice_cleanup=False,
142
- no_lang_auto_detect=False,
143
- agree=True,
144
- api_name="/predict"
145
- )
146
-
147
- # Get the path of the generated wav file
148
- wav_file_path = result
149
-
150
- # Return the generated wav file as a response
151
- return FileResponse(wav_file_path, media_type="audio/wav", filename="output.wav")
 
 
152
 
153
  except Exception as e:
154
  raise HTTPException(status_code=500, detail=str(e))
 
2
  import time
3
  from langchain_core.pydantic_v1 import BaseModel, Field
4
  from fastapi import FastAPI, HTTPException, Query, Request
5
+ from fastapi.responses import StreamingResponse
6
  from fastapi.middleware.cors import CORSMiddleware
7
 
8
  from langchain.chains import LLMChain
 
128
  return "./voices/narator_out.wav"
129
 
130
  @app.post("/generate_wav")
131
+ async def generate_wav(message: VoiceMessage):
132
  try:
133
+ voice = determine_voice_from_npc(message.npc, message.genre)
134
+ audio_file_pth = handle_file(voice)
135
+
136
+ # Generator function to yield audio chunks
137
+ def audio_stream():
138
+ result = tts_client.predict(
139
+ prompt=message.input,
140
+ language=message.language,
141
+ audio_file_pth=audio_file_pth,
142
+ mic_file_path=None,
143
+ use_mic=False,
144
+ voice_cleanup=False,
145
+ no_lang_auto_detect=False,
146
+ agree=True,
147
+ api_name="/predict"
148
+ )
149
+ for sampling_rate, audio_chunk in result:
150
+ yield audio_chunk.tobytes()
151
+
152
+ # Return the generated audio as a streaming response
153
+ return StreamingResponse(audio_stream(), media_type="audio/wav")
154
 
155
  except Exception as e:
156
  raise HTTPException(status_code=500, detail=str(e))