Jofthomas HF staff commited on
Commit
163a5a5
1 Parent(s): 15798be

Update TextGen/router.py

Browse files
Files changed (1) hide show
  1. TextGen/router.py +54 -13
TextGen/router.py CHANGED
@@ -9,7 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
9
  from langchain.chains import LLMChain
10
  from langchain.prompts import PromptTemplate
11
  from TextGen.suno import custom_generate_audio, get_audio_information,generate_lyrics
12
- from TextGen.diffusion import generate_image
13
  #from coqui import predict
14
  from langchain_google_genai import (
15
  ChatGoogleGenerativeAI,
@@ -73,6 +73,11 @@ main_npcs={
73
  "Herbalist":"./voices/female.mp3",
74
  "Bard":"./voices/Bard_voice.mp3"
75
  }
 
 
 
 
 
76
  main_npc_system_prompts={
77
  "Blacksmith":"You are a blacksmith in a video game",
78
  "Herbalist":"You are an herbalist in a video game",
@@ -82,6 +87,10 @@ main_npc_system_prompts={
82
  class Generate(BaseModel):
83
  text:str
84
 
 
 
 
 
85
  def generate_text(messages: List[str], npc:str):
86
  print(npc)
87
  if npc in main_npcs:
@@ -123,6 +132,24 @@ app.add_middleware(
123
  allow_headers=["*"],
124
  )
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  @app.get("/", tags=["Home"])
127
  def api_home():
128
  return {'detail': 'Everchanging Quest backend, nothing to see here'}
@@ -131,6 +158,10 @@ def api_home():
131
  def inference(message: Message):
132
  return generate_text(messages=message.messages, npc=message.npc)
133
 
 
 
 
 
134
  #Dummy function for now
135
  def determine_vocie_from_npc(npc,genre):
136
  if npc in main_npcs:
@@ -142,7 +173,17 @@ def determine_vocie_from_npc(npc,genre):
142
  return"./voices/default_female.mp3"
143
  else:
144
  return "./voices/narator_out.wav"
145
-
 
 
 
 
 
 
 
 
 
 
146
 
147
  @app.post("/generate_wav")
148
  async def generate_wav(message: VoiceMessage):
@@ -234,15 +275,15 @@ async def generate_song():
234
  infos=get_audio_information(f"{data[0]['id']},{data[1]['id']}")
235
  return infos
236
 
237
- @app.post('/generate_image')
238
- def Imagen(image:ImageGen=None):
239
- pil_image =generate_image(image.prompt)
240
-
241
-
242
- # Convert the PIL Image to bytes
243
- img_byte_arr = BytesIO()
244
- pil_image.save(img_byte_arr, format='PNG')
245
- img_byte_arr = img_byte_arr.getvalue()
246
-
247
  # Return the image as a PNG response
248
- return Response(content=img_byte_arr, media_type="image/png")
 
9
  from langchain.chains import LLMChain
10
  from langchain.prompts import PromptTemplate
11
  from TextGen.suno import custom_generate_audio, get_audio_information,generate_lyrics
12
+ #from TextGen.diffusion import generate_image
13
  #from coqui import predict
14
  from langchain_google_genai import (
15
  ChatGoogleGenerativeAI,
 
73
  "Herbalist":"./voices/female.mp3",
74
  "Bard":"./voices/Bard_voice.mp3"
75
  }
76
+ main_npcs_elevenlabs={
77
+ "Blacksmith":"",
78
+ "Herbalist":"",
79
+ "Bard":""
80
+ }
81
  main_npc_system_prompts={
82
  "Blacksmith":"You are a blacksmith in a video game",
83
  "Herbalist":"You are an herbalist in a video game",
 
87
  class Generate(BaseModel):
88
  text:str
89
 
90
+ class Invoke(BaseModel):
91
+ system_prompt:str
92
+ message:str
93
+
94
  def generate_text(messages: List[str], npc:str):
95
  print(npc)
96
  if npc in main_npcs:
 
132
  allow_headers=["*"],
133
  )
134
 
135
+ def inference_model(system_messsage, prompt):
136
+
137
+ new_messages=[{"role": "user", "content": system_messsage},{"role": "user", "content": prompt}]
138
+ llm = ChatGoogleGenerativeAI(
139
+ model="gemini-1.5-pro-latest",
140
+ max_output_tokens=100,
141
+ temperature=1,
142
+ safety_settings={
143
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
144
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
145
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
146
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE
147
+ },
148
+ )
149
+ llm_response = llm.invoke(new_messages)
150
+ print(llm_response)
151
+ return Generate(text=llm_response.content)
152
+
153
  @app.get("/", tags=["Home"])
154
  def api_home():
155
  return {'detail': 'Everchanging Quest backend, nothing to see here'}
 
158
  def inference(message: Message):
159
  return generate_text(messages=message.messages, npc=message.npc)
160
 
161
+ @app.post("/invoke_model", response_model=Generate)
162
+ def story(prompt: Invoke):
163
+ return inference_model(system_messsage=prompt.system_prompt,prompt=prompt.message)
164
+
165
  #Dummy function for now
166
  def determine_vocie_from_npc(npc,genre):
167
  if npc in main_npcs:
 
173
  return"./voices/default_female.mp3"
174
  else:
175
  return "./voices/narator_out.wav"
176
+ #Dummy function for now
177
+ def determine_elevenLav_voice_from_npc(npc,genre):
178
+ if npc in main_npcs:
179
+ return main_npcs[npc]
180
+ else:
181
+ if genre =="Male":
182
+ "./voices/default_male.mp3"
183
+ if genre=="Female":
184
+ return"./voices/default_female.mp3"
185
+ else:
186
+ return "./voices/narator_out.wav"
187
 
188
  @app.post("/generate_wav")
189
  async def generate_wav(message: VoiceMessage):
 
275
  infos=get_audio_information(f"{data[0]['id']},{data[1]['id']}")
276
  return infos
277
 
278
+ #@app.post('/generate_image')
279
+ #def Imagen(image:ImageGen=None):
280
+ # pil_image =generate_image(image.prompt)
281
+ #
282
+ #
283
+ # # Convert the PIL Image to bytes
284
+ # img_byte_arr = BytesIO()
285
+ # pil_image.save(img_byte_arr, format='PNG')
286
+ # img_byte_arr = img_byte_arr.getvalue()
287
+ #
288
  # Return the image as a PNG response
289
+ # return Response(content=img_byte_arr, media_type="image/png")