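"""JobEasy AI: a Streamlit chat app that interviews the user about their job,
company, and tools with a two-node LangGraph MessageGraph, then generates a
professional profile from the collected details."""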
import os
import uuid
from typing import List

import streamlit as st
from streamlit_chat import message
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import MessageGraph, END

template = """Your job is to get information from a user about their profession. We are aiming to generate a profile later.
You should get the following information from them:
- Job
- Company
- Tools they use, e.g. for a software engineer, which frameworks/languages
If you are not able to discern this information, ask them to clarify! Do not attempt to wildly guess.
When asking anything, please be friendly and comment on any of the info you have found, e.g. "working at X company must have been a thrilling challenge".
Ask one question at a time.
After you are able to discern all the information, call the relevant tool."""
# The OpenAI API key must be supplied via the OPENAI_API_KEY environment variable
# (or Streamlit secrets); never hard-code secrets in source.
if not os.environ.get("OPENAI_API_KEY"):
    st.error("Please set the OPENAI_API_KEY environment variable.")
    st.stop()
llm = ChatOpenAI(temperature=0)
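# Prepend the interview instructions to the running conversation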
def get_messages_info(messages):
    return [SystemMessage(content=template)] + messages
class PromptInstructions(BaseModel):
    """Instructions on how to prompt the LLM."""
    job: str
    company: str
    technologies: List[str]
    hobbies: List[str]
llm_with_tool = llm.bind_tools([PromptInstructions])
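# The "info" node: build the prompted message list, then call the tool-bound model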
chain = get_messages_info | llm_with_tool
# Helper function for determining if tool was called
def _is_tool_call(msg):
    return hasattr(msg, "additional_kwargs") and 'tool_calls' in msg.additional_kwargs
# New system prompt
prompt_system = """Based on the following context, write a good professional profile. Infer the soft skills:
{reqs}"""
# Function to get the messages for the profile
# Will only get messages AFTER the tool call
def get_profile_messages(messages):
    tool_call = None
    other_msgs = []
    for m in messages:
        if _is_tool_call(m):
            tool_call = m.additional_kwargs['tool_calls'][0]['function']['arguments']
        elif tool_call is not None:
            other_msgs.append(m)
    return [SystemMessage(content=prompt_system.format(reqs=tool_call))] + other_msgs
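# The "profile" node: build the profile prompt from the captured tool arguments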
profile_gen_chain = get_profile_messages | llm
def get_state(messages):
    # Route to the profile node once the model has called the tool;
    # stop after an assistant reply; otherwise keep gathering info.
    if _is_tool_call(messages[-1]):
        return "profile"
    elif not isinstance(messages[-1], HumanMessage):
        return END
    for m in messages:
        if _is_tool_call(m):
            return "profile"
    return "info"
@st.cache_resource
def get_graph():
    memory = SqliteSaver.from_conn_string(":memory:")
    nodes = {k: k for k in ['info', 'profile', END]}
    workflow = MessageGraph()
    workflow.add_node("info", chain)
    workflow.add_node("profile", profile_gen_chain)
    workflow.add_conditional_edges("info", get_state, nodes)
    workflow.add_conditional_edges("profile", get_state, nodes)
    workflow.set_entry_point("info")
    graph = workflow.compile(checkpointer=memory)
    return graph
graph = get_graph()
# Keep one checkpoint thread per browser session so the graph retains memory across reruns
if 'thread_id' not in st.session_state:
    st.session_state['thread_id'] = str(uuid.uuid4())
config = {"configurable": {"thread_id": st.session_state['thread_id']}}
# Streamlit app layout
st.title("JobEasy AI")
clear_button = st.sidebar.button("Clear Conversation", key="clear")
# Initialise session state variables
if 'generated' not in st.session_state:
    st.session_state['generated'] = ['Please tell me about your most recent role']
if 'past' not in st.session_state:
    st.session_state['past'] = []
if 'messages' not in st.session_state:
    st.session_state['messages'] = []
# reset everything, including the graph's checkpoint thread
if clear_button:
    st.session_state['generated'] = ['Please tell me about your most recent role']
    st.session_state['past'] = []
    st.session_state['messages'] = []
    st.session_state['thread_id'] = str(uuid.uuid4())
# container for chat history
response_container = st.container()
# container for text box
container = st.container()
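# Send one user turn through the graph and record the streamed assistant replies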
def query(payload):
    st.session_state['past'].append(payload)
    for output in graph.stream([HumanMessage(content=payload)], config=config):
        if "__end__" in output:
            continue
        # stream() yields dictionaries with output keyed by node name
        for key, value in output.items():
            # Skip tool-call turns, which carry no visible text
            if not value.content:
                continue
            st.session_state['messages'].append({"role": "assistant", "content": value.content})
            st.session_state['generated'].append(value.content)
with container:
    with st.form(key='my_form', clear_on_submit=True):
        user_input = st.text_area("You:", key='input', height=100)
        submit_button = st.form_submit_button(label='Send')
    if submit_button and user_input:
        query(user_input)
if st.session_state['generated']:
    with response_container:
        for i in range(len(st.session_state['generated'])):
            message(st.session_state["generated"][i], key=str(i))
            if len(st.session_state["past"]) > 0 and i < len(st.session_state["past"]):
                message(st.session_state["past"][i], is_user=True, key=str(i) + '_user')