Update TextGen/router.py
Browse files- TextGen/router.py +54 -13
TextGen/router.py
CHANGED
@@ -9,7 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
9 |
from langchain.chains import LLMChain
|
10 |
from langchain.prompts import PromptTemplate
|
11 |
from TextGen.suno import custom_generate_audio, get_audio_information,generate_lyrics
|
12 |
-
from TextGen.diffusion import generate_image
|
13 |
#from coqui import predict
|
14 |
from langchain_google_genai import (
|
15 |
ChatGoogleGenerativeAI,
|
@@ -73,6 +73,11 @@ main_npcs={
|
|
73 |
"Herbalist":"./voices/female.mp3",
|
74 |
"Bard":"./voices/Bard_voice.mp3"
|
75 |
}
|
|
|
|
|
|
|
|
|
|
|
76 |
main_npc_system_prompts={
|
77 |
"Blacksmith":"You are a blacksmith in a video game",
|
78 |
"Herbalist":"You are an herbalist in a video game",
|
@@ -82,6 +87,10 @@ main_npc_system_prompts={
|
|
82 |
class Generate(BaseModel):
|
83 |
text:str
|
84 |
|
|
|
|
|
|
|
|
|
85 |
def generate_text(messages: List[str], npc:str):
|
86 |
print(npc)
|
87 |
if npc in main_npcs:
|
@@ -123,6 +132,24 @@ app.add_middleware(
|
|
123 |
allow_headers=["*"],
|
124 |
)
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
@app.get("/", tags=["Home"])
|
127 |
def api_home():
|
128 |
return {'detail': 'Everchanging Quest backend, nothing to see here'}
|
@@ -131,6 +158,10 @@ def api_home():
|
|
131 |
def inference(message: Message):
|
132 |
return generate_text(messages=message.messages, npc=message.npc)
|
133 |
|
|
|
|
|
|
|
|
|
134 |
#Dummy function for now
|
135 |
def determine_vocie_from_npc(npc,genre):
|
136 |
if npc in main_npcs:
|
@@ -142,7 +173,17 @@ def determine_vocie_from_npc(npc,genre):
|
|
142 |
return"./voices/default_female.mp3"
|
143 |
else:
|
144 |
return "./voices/narator_out.wav"
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
@app.post("/generate_wav")
|
148 |
async def generate_wav(message: VoiceMessage):
|
@@ -234,15 +275,15 @@ async def generate_song():
|
|
234 |
infos=get_audio_information(f"{data[0]['id']},{data[1]['id']}")
|
235 |
return infos
|
236 |
|
237 |
-
|
238 |
-
def Imagen(image:ImageGen=None):
|
239 |
-
pil_image =generate_image(image.prompt)
|
240 |
-
|
241 |
-
|
242 |
-
# Convert the PIL Image to bytes
|
243 |
-
img_byte_arr = BytesIO()
|
244 |
-
pil_image.save(img_byte_arr, format='PNG')
|
245 |
-
img_byte_arr = img_byte_arr.getvalue()
|
246 |
-
|
247 |
# Return the image as a PNG response
|
248 |
-
return Response(content=img_byte_arr, media_type="image/png")
|
|
|
9 |
from langchain.chains import LLMChain
|
10 |
from langchain.prompts import PromptTemplate
|
11 |
from TextGen.suno import custom_generate_audio, get_audio_information,generate_lyrics
|
12 |
+
#from TextGen.diffusion import generate_image
|
13 |
#from coqui import predict
|
14 |
from langchain_google_genai import (
|
15 |
ChatGoogleGenerativeAI,
|
|
|
73 |
"Herbalist":"./voices/female.mp3",
|
74 |
"Bard":"./voices/Bard_voice.mp3"
|
75 |
}
|
76 |
+
main_npcs_elevenlabs={
|
77 |
+
"Blacksmith":"",
|
78 |
+
"Herbalist":"",
|
79 |
+
"Bard":""
|
80 |
+
}
|
81 |
main_npc_system_prompts={
|
82 |
"Blacksmith":"You are a blacksmith in a video game",
|
83 |
"Herbalist":"You are an herbalist in a video game",
|
|
|
87 |
class Generate(BaseModel):
|
88 |
text:str
|
89 |
|
90 |
+
class Invoke(BaseModel):
|
91 |
+
system_prompt:str
|
92 |
+
message:str
|
93 |
+
|
94 |
def generate_text(messages: List[str], npc:str):
|
95 |
print(npc)
|
96 |
if npc in main_npcs:
|
|
|
132 |
allow_headers=["*"],
|
133 |
)
|
134 |
|
135 |
+
def inference_model(system_messsage, prompt):
|
136 |
+
|
137 |
+
new_messages=[{"role": "user", "content": system_messsage},{"role": "user", "content": prompt}]
|
138 |
+
llm = ChatGoogleGenerativeAI(
|
139 |
+
model="gemini-1.5-pro-latest",
|
140 |
+
max_output_tokens=100,
|
141 |
+
temperature=1,
|
142 |
+
safety_settings={
|
143 |
+
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
144 |
+
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
145 |
+
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
146 |
+
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE
|
147 |
+
},
|
148 |
+
)
|
149 |
+
llm_response = llm.invoke(new_messages)
|
150 |
+
print(llm_response)
|
151 |
+
return Generate(text=llm_response.content)
|
152 |
+
|
153 |
@app.get("/", tags=["Home"])
|
154 |
def api_home():
|
155 |
return {'detail': 'Everchanging Quest backend, nothing to see here'}
|
|
|
158 |
def inference(message: Message):
|
159 |
return generate_text(messages=message.messages, npc=message.npc)
|
160 |
|
161 |
+
@app.post("/invoke_model", response_model=Generate)
|
162 |
+
def story(prompt: Invoke):
|
163 |
+
return inference_model(system_messsage=prompt.system_prompt,prompt=prompt.message)
|
164 |
+
|
165 |
#Dummy function for now
|
166 |
def determine_vocie_from_npc(npc,genre):
|
167 |
if npc in main_npcs:
|
|
|
173 |
return"./voices/default_female.mp3"
|
174 |
else:
|
175 |
return "./voices/narator_out.wav"
|
176 |
+
#Dummy function for now
|
177 |
+
def determine_elevenLav_voice_from_npc(npc,genre):
|
178 |
+
if npc in main_npcs:
|
179 |
+
return main_npcs[npc]
|
180 |
+
else:
|
181 |
+
if genre =="Male":
|
182 |
+
"./voices/default_male.mp3"
|
183 |
+
if genre=="Female":
|
184 |
+
return"./voices/default_female.mp3"
|
185 |
+
else:
|
186 |
+
return "./voices/narator_out.wav"
|
187 |
|
188 |
@app.post("/generate_wav")
|
189 |
async def generate_wav(message: VoiceMessage):
|
|
|
275 |
infos=get_audio_information(f"{data[0]['id']},{data[1]['id']}")
|
276 |
return infos
|
277 |
|
278 |
+
#@app.post('/generate_image')
|
279 |
+
#def Imagen(image:ImageGen=None):
|
280 |
+
# pil_image =generate_image(image.prompt)
|
281 |
+
#
|
282 |
+
#
|
283 |
+
# # Convert the PIL Image to bytes
|
284 |
+
# img_byte_arr = BytesIO()
|
285 |
+
# pil_image.save(img_byte_arr, format='PNG')
|
286 |
+
# img_byte_arr = img_byte_arr.getvalue()
|
287 |
+
#
|
288 |
# Return the image as a PNG response
|
289 |
+
# return Response(content=img_byte_arr, media_type="image/png")
|