Spaces:
Runtime error
Runtime error
ageraustine
committed on
Commit
•
f3c8871
1
Parent(s):
b1dbc84
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_core.messages import SystemMessage
|
2 |
+
from langchain_openai import ChatOpenAI
|
3 |
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
4 |
+
from langgraph.graph import MessageGraph, END
|
5 |
+
from langgraph.checkpoint.sqlite import SqliteSaver
|
6 |
+
from langchain_core.messages import HumanMessage
|
7 |
+
import streamlit as st
|
8 |
+
from typing import List
|
9 |
+
import os
|
10 |
+
import uuid
|
11 |
+
|
12 |
+
# System prompt for the information-gathering ("info") node: the model must
# collect job, company and tooling details before calling the schema tool.
# NOTE(review): "discerne" is a typo, but this string is runtime behavior
# (it is sent to the model), so it is deliberately left untouched here.
template = """Your job is to get information from a user about their profession. We are aiming to generate a profile later

You should get the following information from them:

- Their most recent job
- Company
- tools for example for a software engineer(which frameworks/languages)

If you are not able to discerne this info, ask them to clarify! Do not attempt to wildly guess.
If you're asking anything please be friendly and comment on any of the info you have found e.g working at x company must have been a thrilling challenge

After you are able to discerne all the information, call the relevant tool"""

# Chat model shared by both graph nodes; temperature=0 for deterministic replies.
llm = ChatOpenAI(temperature=0)
|
26 |
+
|
27 |
+
def get_messages_info(messages):
    """Prepend the interview system prompt to the running conversation."""
    system_msg = SystemMessage(content=template)
    return [system_msg, *messages]
|
29 |
+
|
30 |
+
|
31 |
+
class PromptInstructions(BaseModel):
    """Instructions on how to prompt the LLM."""

    # Most recent job title reported by the user.
    job: str
    # Company the user worked at.
    company: str
    # Frameworks/languages/tools the user works with.
    tools: List[str]
    # Skill list; soft skills are inferred later by the profile prompt.
    skills: List[str]
|
37 |
+
|
38 |
+
# Expose PromptInstructions as a callable tool so the model can signal that
# it has gathered every required field.
llm_with_tool = llm.bind_tools([PromptInstructions])

# Runnable for the "info" node: prepend the system prompt, then invoke the
# tool-bound model.
chain = get_messages_info | llm_with_tool
|
41 |
+
|
42 |
+
# Helper function for determining if tool was called
|
43 |
+
def _is_tool_call(msg):
|
44 |
+
return hasattr(msg, "additional_kwargs") and 'tool_calls' in msg.additional_kwargs
|
45 |
+
|
46 |
+
|
47 |
+
# System prompt for the "profile" node; {reqs} is filled with the raw JSON
# arguments captured from the model's tool call.
prompt_system = """Based on the following context, write a good professional profile. Infer the soft skills:

{reqs}"""
|
51 |
+
|
52 |
+
# Build the message list for the profile-generation node.
# Only messages that come AFTER the tool call are forwarded.
def get_profile_messages(messages):
    """Turn the conversation into a prompt for profile generation.

    Captures the arguments of the tool call (the last one wins) and forwards
    only the messages that followed it, behind a formatted system prompt.
    """
    tool_args = None
    trailing = []
    for msg in messages:
        if _is_tool_call(msg):
            tool_args = msg.additional_kwargs["tool_calls"][0]["function"]["arguments"]
        elif tool_args is not None:
            trailing.append(msg)
    system = SystemMessage(content=prompt_system.format(reqs=tool_args))
    return [system] + trailing
|
63 |
+
|
64 |
+
# Runnable for the "profile" node: build the profile prompt, then invoke the model.
profile_gen_chain = get_profile_messages | llm
|
65 |
+
|
66 |
+
def get_state(messages):
    """Route to the next graph node given the conversation so far.

    Returns "profile" once a tool call has appeared, END when the last
    message is not from the human, and "info" otherwise.
    """
    last = messages[-1]
    if _is_tool_call(last):
        return "profile"
    if not isinstance(last, HumanMessage):
        return END
    if any(_is_tool_call(m) for m in messages):
        return "profile"
    return "info"
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
# Checkpointer that persists graph state in an in-memory SQLite database
# (state lives only for the lifetime of this process).
memory = SqliteSaver.from_conn_string(":memory:")

# Conditional-edge mapping: each state name returned by get_state routes to
# the node of the same name ('info', 'profile', or the END sentinel).
nodes = {k:k for k in ['info', 'profile', END]}
workflow = MessageGraph()
workflow.add_node("info", chain)                 # gathers profession details
workflow.add_node("profile", profile_gen_chain)  # writes the final profile
workflow.add_conditional_edges("info", get_state, nodes)
workflow.add_conditional_edges("profile", get_state, nodes)
workflow.set_entry_point("info")
graph = workflow.compile(checkpointer=memory)

# One conversation thread per app process; thread_id keys the checkpoints.
config = {"configurable": {"thread_id": str(uuid.uuid4())}}
|
90 |
+
|
91 |
+
# Helper function to execute the LangChain logic
def execute_langchain(user_input):
    """Stream the user's message through the graph and collect node outputs.

    Returns a list of (node_name, value) pairs, skipping the END marker.
    """
    collected = []
    wrapped = [HumanMessage(content=user_input)]
    for step in graph.stream(wrapped, config=config):
        if "__end__" in step:
            continue
        collected.extend(step.items())
    return collected
|
100 |
+
|
101 |
+
# Streamlit app layout
st.title("LangChain Chat")

user_input = st.text_input("You:", "")

if st.button("Send"):
    if user_input:
        # Execute LangChain logic
        outputs = execute_langchain(user_input)

        # Display LangChain outputs. graph.stream() yields {node_name: value}
        # mappings, so the keys collected by execute_langchain are the node
        # names ("info" / "profile"). The original compared against the prose
        # labels "Output from node 'info':" etc., which never matched, so no
        # reply was ever rendered.
        for key, value in outputs:
            if key in ("info", "profile"):
                # value is a chat message object from the model, not a str;
                # render its text content (fall back to str for safety).
                st.text("Bot: " + getattr(value, "content", str(value)))

        st.text("\n---\n")

# Allow the user to quit the chat
if st.button("Quit"):
    st.text("Bot: Byebye")
|
123 |
+
|