Jofthomas HF staff commited on
Commit
9198ac8
1 Parent(s): 676cc47

Update TextGen/router.py

Browse files
Files changed (1) hide show
  1. TextGen/router.py +18 -15
TextGen/router.py CHANGED
@@ -39,14 +39,25 @@ main_npcs={
39
  "Herbalist":"./voices/female.mp3",
40
  "Bard":"./voices/Bard_voice.mp3"
41
  }
 
 
 
 
 
42
  class Generate(BaseModel):
43
  text:str
44
 
45
- def generate_text(messages: List[str]):
46
- print(messages)
47
- prompt=messages[-1]
48
- prompt = PromptTemplate(template=prompt, input_variables=['Prompt'])
49
-
 
 
 
 
 
 
50
  # Initialize the LLM
51
  llm = ChatGoogleGenerativeAI(
52
  model="gemini-pro",
@@ -54,17 +65,9 @@ def generate_text(messages: List[str]):
54
  HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
55
  },
56
  )
57
-
58
- llmchain = LLMChain(
59
- prompt=prompt,
60
- llm=llm
61
- )
62
-
63
- llm_response = llmchain.run({"Prompt": prompt})
64
  return Generate(text=llm_response)
65
 
66
-
67
-
68
  app.add_middleware(
69
  CORSMiddleware,
70
  allow_origins=["*"],
@@ -79,7 +82,7 @@ def api_home():
79
 
80
  @app.post("/api/generate", summary="Generate text from prompt", tags=["Generate"], response_model=Generate)
81
  def inference(message: Message):
82
- return generate_text(messages=message.messages)
83
 
84
  #Dummy function for now
85
  def determine_vocie_from_npc(npc,genre):
 
39
  "Herbalist":"./voices/female.mp3",
40
  "Bard":"./voices/Bard_voice.mp3"
41
  }
42
+ main_npc_system_prompts={
43
+ "Blacksmith":"You are a blacksmith in a video game",
44
+ "Herbalist":"You are an herbalist in a video game",
45
+ "Bard":"You are a bard in a video game"
46
+ }
47
  class Generate(BaseModel):
48
  text:str
49
 
50
+ def generate_text(messages: List[str], npc:str):
51
+ print(npc)
52
+ system_prompt=main_npc_system_prompts[npc]
53
+ print(system_prompt)
54
+ new_messages=[{"role": "system", "content": system_prompt}]
55
+ for index, message in enumerate(messages):
56
+ if index%2==0:
57
+ new_messages.append({"role": "user", "content": message})
58
+ else:
59
+ new_messages.append({"role": "assistant", "content": message})
60
+ print(new_messages)
61
  # Initialize the LLM
62
  llm = ChatGoogleGenerativeAI(
63
  model="gemini-pro",
 
65
  HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
66
  },
67
  )
68
+ llm_response = llm.invoke(new_messages)
 
 
 
 
 
 
69
  return Generate(text=llm_response)
70
 
 
 
71
  app.add_middleware(
72
  CORSMiddleware,
73
  allow_origins=["*"],
 
82
 
83
  @app.post("/api/generate", summary="Generate text from prompt", tags=["Generate"], response_model=Generate)
84
  def inference(message: Message):
85
+ return generate_text(messages=message.messages, npc=message.npc)
86
 
87
  #Dummy function for now
88
  def determine_vocie_from_npc(npc,genre):