ageraustine commited on
Commit
f3c8871
1 Parent(s): b1dbc84

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.messages import SystemMessage
2
+ from langchain_openai import ChatOpenAI
3
+ from langchain_core.pydantic_v1 import BaseModel, Field
4
+ from langgraph.graph import MessageGraph, END
5
+ from langgraph.checkpoint.sqlite import SqliteSaver
6
+ from langchain_core.messages import HumanMessage
7
+ import streamlit as st
8
+ from typing import List
9
+ import os
10
+ import uuid
11
+
12
+ template = """Your job is to get information from a user about their profession. We are aiming to generate a profile later
13
+
14
+ You should get the following information from them:
15
+
16
+ - Their most recent job
17
+ - Company
18
+ - tools for example for a software engineer(which frameworks/languages)
19
+
20
+ If you are not able to discerne this info, ask them to clarify! Do not attempt to wildly guess.
21
+ If you're asking anything please be friendly and comment on any of the info you have found e.g working at x company must have been a thrilling challenge
22
+
23
+ After you are able to discerne all the information, call the relevant tool"""
24
+
25
+ llm = ChatOpenAI(temperature=0)
26
+
27
+ def get_messages_info(messages):
28
+ return [SystemMessage(content=template)] + messages
29
+
30
+
31
+ class PromptInstructions(BaseModel):
32
+ """Instructions on how to prompt the LLM."""
33
+ job: str
34
+ company: str
35
+ tools: List[str]
36
+ skills: List[str]
37
+
38
+ llm_with_tool = llm.bind_tools([PromptInstructions])
39
+
40
+ chain = get_messages_info | llm_with_tool
41
+
42
+ # Helper function for determining if tool was called
43
+ def _is_tool_call(msg):
44
+ return hasattr(msg, "additional_kwargs") and 'tool_calls' in msg.additional_kwargs
45
+
46
+
47
+ # New system prompt
48
+ prompt_system = """Based on the following context, write a good professional profile. Infer the soft skills:
49
+
50
+ {reqs}"""
51
+
52
+ # Function to get the messages for the profile
53
+ # Will only get messages AFTER the tool call
54
+ def get_profile_messages(messages):
55
+ tool_call = None
56
+ other_msgs = []
57
+ for m in messages:
58
+ if _is_tool_call(m):
59
+ tool_call = m.additional_kwargs['tool_calls'][0]['function']['arguments']
60
+ elif tool_call is not None:
61
+ other_msgs.append(m)
62
+ return [SystemMessage(content=prompt_system.format(reqs=tool_call))] + other_msgs
63
+
64
+ profile_gen_chain = get_profile_messages | llm
65
+
66
+ def get_state(messages):
67
+ if _is_tool_call(messages[-1]):
68
+ return "profile"
69
+ elif not isinstance(messages[-1], HumanMessage):
70
+ return END
71
+ for m in messages:
72
+ if _is_tool_call(m):
73
+ return "profile"
74
+ return "info"
75
+
76
+
77
+
78
+ memory = SqliteSaver.from_conn_string(":memory:")
79
+
80
+ nodes = {k:k for k in ['info', 'profile', END]}
81
+ workflow = MessageGraph()
82
+ workflow.add_node("info", chain)
83
+ workflow.add_node("profile", profile_gen_chain)
84
+ workflow.add_conditional_edges("info", get_state, nodes)
85
+ workflow.add_conditional_edges("profile", get_state, nodes)
86
+ workflow.set_entry_point("info")
87
+ graph = workflow.compile(checkpointer=memory)
88
+
89
+ config = {"configurable": {"thread_id": str(uuid.uuid4())}}
90
+
91
+ # Helper function to execute the LangChain logic
92
+ def execute_langchain(user_input):
93
+ output_list = []
94
+ for output in graph.stream([HumanMessage(content=user_input)], config=config):
95
+ if "__end__" in output:
96
+ continue
97
+ for key, value in output.items():
98
+ output_list.append((key, value))
99
+ return output_list
100
+
101
+ # Streamlit app layout
102
+ st.title("LangChain Chat")
103
+
104
+ user_input = st.text_input("You:", "")
105
+
106
+ if st.button("Send"):
107
+ if user_input:
108
+ # Execute LangChain logic
109
+ outputs = execute_langchain(user_input)
110
+
111
+ # Display LangChain outputs
112
+ for key, value in outputs:
113
+ if key == "Output from node 'info':":
114
+ st.text("Bot: " + value)
115
+ elif key == "Output from node 'profile':":
116
+ st.text("Bot: " + value)
117
+
118
+ st.text("\n---\n")
119
+
120
+ # Allow the user to quit the chat
121
+ if st.button("Quit"):
122
+ st.text("Bot: Byebye")
123
+