|
import streamlit as st |
|
from groq import Groq |
|
from typing import List, Optional |
|
from dotenv import load_dotenv |
|
import json, os |
|
from pydantic import BaseModel |
|
from dspy_inference import get_expanded_query_and_topic |
|
|
|
# Populate the process environment from a .env file before the API key is read.
load_dotenv()

# Shared Groq client used for every chat-completion request in this app.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
|
# Emoji avatars passed to st.chat_message for user and assistant bubbles.
# Written as escape sequences so the glyphs survive any re-encoding of this
# file: U+1F464 is "bust in silhouette", U+1F916 is "robot face".
# NOTE(review): the previous literals appeared mis-encoded; confirm these
# are the intended emoji.
USER_AVATAR = "\U0001F464"

BOT_AVATAR = "\U0001F916"
|
|
|
def _greeting_history():
    """Return a new history list holding only the assistant's opening greeting."""
    return [{"role": "assistant", "content": "Hi, How can I help you today?"}]


# Seed both chat histories when they are not yet present in session state.
# ``messages`` is what main() draws on screen; ``conversation_state`` is what
# it sends to the model. Each key receives its own freshly built list, so
# appending to one history never alters the other.
for _state_key in ("messages", "conversation_state"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _greeting_history()
|
|
|
# System prompt sent as the first message of every completion request.
_SYSTEM_PROMPT = "You are a helpful assistant who can answer any question that the user asks.\n\n"


def _has_expansion(expand: object) -> bool:
    """Return True if the query-expansion step produced text worth using.

    The expansion helper appears to signal "no expansion" with the string
    ``"None"``. A real ``None`` or a blank string is treated the same way,
    so neither is injected into the prompt or shown as the expanded question.
    """
    if expand is None:
        return False
    text = str(expand).strip()
    return bool(text) and text != "None"


def _render_history() -> None:
    """Redraw every message in ``st.session_state.messages`` with its role's avatar."""
    for message in st.session_state.messages:
        avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])


def _build_context(history: List[dict], expanded_query) -> List[dict]:
    """Assemble the message list sent to the chat-completion API.

    Order: system prompt, then the conversation (already ending with the new
    user turn), then the expansion (when present) and the topic as extra
    system messages. *history* itself is not modified.
    """
    context = [{"role": "system", "content": _SYSTEM_PROMPT}, *history]
    if _has_expansion(expanded_query.expand):
        context.append({"role": "system", "content": f"Expanded query: {expanded_query.expand}"})
    context.append({"role": "system", "content": f"Topic: {expanded_query.topic}"})
    return context


def _request_completion(context: List[dict]):
    """Start a streaming chat completion on the Groq client for *context*."""
    # NOTE(review): confirm this model id is still served by Groq.
    return client.chat.completions.create(
        messages=context,
        model="llama-3.1-405b-reasoning",
        temperature=0,
        top_p=1,
        stop=None,
        stream=True,
    )


def _stream_reply(response) -> str:
    """Render a streamed completion incrementally and return the full text.

    Call this inside the assistant ``st.chat_message`` container so that the
    placeholder is drawn in that bubble.
    """
    reply = ""
    placeholder = st.empty()
    for chunk in response:
        # Skip chunks that carry no choices instead of raising IndexError.
        if not chunk.choices:
            continue
        piece = chunk.choices[0].delta.content
        if piece:
            reply += piece
            placeholder.markdown(reply)
    return reply


def _render_expansion_details(expanded_query) -> None:
    """Show the expanded question (or that none was needed) and the topic."""
    st.markdown("---")
    if _has_expansion(expanded_query.expand):
        st.markdown(f"**Expanded Question:** {expanded_query.expand}")
    else:
        st.markdown("**Expanded Question:** No expansion needed")
    st.markdown(f"**Topic:** {expanded_query.topic}")


def main():
    """Render the chat page and handle at most one new user prompt.

    A new prompt is expanded and tagged by ``get_expanded_query_and_topic``,
    answered by a streamed Groq completion, and displayed with the expanded
    question and topic under the reply. The two histories in session state
    are updated only after the reply has fully streamed. A failed or
    interrupted request therefore leaves no unanswered user turn behind to
    be re-sent to the model later.
    """
    st.title("Query expansion and tagging")
    _render_history()

    prompt = st.chat_input("User input")
    if not prompt:
        return

    st.chat_message("user", avatar=USER_AVATAR).markdown(prompt)

    # Work on a copy so session state is untouched until the turn succeeds.
    conversation_context = [
        *st.session_state["conversation_state"],
        {"role": "user", "content": prompt},
    ]
    expanded_query = get_expanded_query_and_topic(prompt, conversation_context)
    response = _request_completion(_build_context(conversation_context, expanded_query))

    with st.chat_message("assistant", avatar=BOT_AVATAR):
        reply = _stream_reply(response)
        _render_expansion_details(expanded_query)

    # Commit the completed turn to both histories (separate dict objects each).
    for history in (st.session_state.messages, st.session_state["conversation_state"]):
        history.append({"role": "user", "content": prompt})
        history.append({"role": "assistant", "content": reply})
|
|
|
# Entry point when this file is executed as the main script
# (presumably launched with `streamlit run`).
if __name__ == "__main__":
    main()