likhithv commited on
Commit
b987573
1 Parent(s): a460188

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +80 -0
  2. dspy_inference.py +131 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from groq import Groq
3
+ from typing import List, Optional
4
+ from dotenv import load_dotenv
5
+ import json, os
6
+ from pydantic import BaseModel
7
+ from dspy_inference import get_expanded_query_and_topic
8
+
9
+ load_dotenv()
10
+
11
+ client = Groq(api_key=os.getenv("GROQ_API_KEY"))
12
+ USER_AVATAR = "👤"
13
+ BOT_AVATAR = "🤖"
14
+
15
+ if "messages" not in st.session_state:
16
+ st.session_state.messages = [{"role": "assistant", "content": "Hi, How can I help you today?"}]
17
+ if "conversation_state" not in st.session_state:
18
+ st.session_state["conversation_state"] = [{"role": "assistant", "content": "Hi, How can I help you today?"}]
19
+
20
+ def main():
21
+ st.title("Assignment")
22
+
23
+ for message in st.session_state.messages:
24
+ image = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
25
+ with st.chat_message(message["role"], avatar=image):
26
+ st.markdown(message["content"])
27
+
28
+ system_prompt = f'''You are a helpful assistant who can answer any question that the user asks.
29
+ '''
30
+ if prompt := st.chat_input("User input"):
31
+ st.chat_message("user", avatar=USER_AVATAR).markdown(prompt)
32
+ st.session_state.messages.append({"role": "user", "content": prompt})
33
+ conversation_context = st.session_state["conversation_state"]
34
+ conversation_context.append({"role": "user", "content": prompt})
35
+
36
+ # Use dspy to expand the query and get the topic
37
+ expanded_query = get_expanded_query_and_topic(prompt, conversation_context)
38
+
39
+ context = []
40
+ context.append({"role": "system", "content": system_prompt})
41
+ context.extend(st.session_state["conversation_state"])
42
+
43
+ # Add the expanded query to the context
44
+ if expanded_query.expand != "None":
45
+ context.append({"role": "system", "content": f"Expanded query: {expanded_query.expand}"})
46
+ context.append({"role": "system", "content": f"Topic: {expanded_query.topic}"})
47
+
48
+ response = client.chat.completions.create(
49
+ messages=context,
50
+ model="llama3-70b-8192",
51
+ temperature=0,
52
+ top_p=1,
53
+ stop=None,
54
+ stream=True,
55
+ )
56
+
57
+ with st.chat_message("assistant", avatar=BOT_AVATAR):
58
+ result = ""
59
+ res_box = st.empty()
60
+ for chunk in response:
61
+ if chunk.choices[0].delta.content:
62
+ new_content = chunk.choices[0].delta.content
63
+ result += new_content
64
+ res_box.markdown(f'{result}')
65
+
66
+ # Display expanded question and tags separately
67
+ st.markdown("---")
68
+ # st.markdown("**Query Analysis:**")
69
+ if expanded_query.expand != "None":
70
+ st.markdown(f"**Expanded Question:** {expanded_query.expand}")
71
+ else:
72
+ st.markdown("**Expanded Question:** No expansion needed")
73
+ st.markdown(f"**Topic:** {expanded_query.topic}")
74
+
75
+ assistant_response = result
76
+ st.session_state.messages.append({"role": "assistant", "content": assistant_response})
77
+ conversation_context.append({"role": "assistant", "content": assistant_response})
78
+
79
+ if __name__ == '__main__':
80
+ main()
dspy_inference.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List
2
+ import dspy
3
+ import copy
4
+ import dspy.evaluate
5
+ from pydantic import BaseModel
6
+ from dotenv import load_dotenv
7
+ import os
8
+ from dspy.teleprompt import BootstrapFewShotWithRandomSearch
9
+
10
+ load_dotenv()
11
+
12
+ class Agent(dspy.Module):
13
+ """
14
+ Base Agent Module
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ model: Optional[str] | None = "llama3",
20
+ client: Optional[str] | None = "ollama",
21
+ max_tokens: Optional[int] | None = 4096,
22
+ temperature: Optional[float] | None = 0.5,
23
+ ) -> None:
24
+ """
25
+ Initialising Agent Module
26
+
27
+ Args:
28
+ model: str -> default = llama3
29
+ client: str -> default = ollama
30
+ max_tokens: int -> default = 4096
31
+ temperature: float -> default = 0.5
32
+ tools: List[Tool] -> default = None
33
+ """
34
+
35
+ self.model = dspy.GROQ(
36
+ model="llama3-8b-8192",
37
+ temperature=temperature,
38
+ api_key=os.getenv("GROQ_API_KEY"),
39
+ max_tokens=max_tokens,
40
+ frequency_penalty=1.5,
41
+ presence_penalty=1.5,
42
+ )
43
+
44
+ dspy.settings.configure(
45
+ lm=self.model,
46
+ max_tokens = max_tokens,
47
+ temperature = temperature
48
+ )
49
+
50
+ def __deepcopy__(self, memo):
51
+ new_instance = self.__class__.__new__(self.__class__)
52
+ memo[id(self)] = new_instance
53
+ for k, v in self.__dict__.items():
54
+ if k != 'model':
55
+ setattr(new_instance, k, copy.deepcopy(v, memo))
56
+ new_instance.model = self.model
57
+ return new_instance
58
+
59
+ class OutputFormat(BaseModel):
60
+ expand: Optional[str]
61
+ topic: str
62
+
63
+ class Conversation(BaseModel):
64
+ role: str
65
+ content: str
66
+
67
+ class Memory(BaseModel):
68
+ conversations: List[Conversation]
69
+
70
+ class BaseSignature(dspy.Signature):
71
+ """
72
+ You are an expert in expanding the user question and generating suitable tags for the question.
73
+ Follow the exact instructions given:
74
+ 1. Expand with only single question.
75
+ 2. Try to keep the actual content in the expand question. Example: User question: What is math ?, expand: What is mathematics ?
76
+ 3. Tags should be 2-level hierarchy topics. Eg - India - Politics, Sports- Football. Tags should be as specific as possible. If it is a general question topic: GENERAL
77
+ 4. Do not give the reference of the previous question in the expanded question.
78
+ 5. If there is no expanded version of the user question, then give it as expand = "None"
79
+ 6. If there is a general question asked, do not expand the question, just give it as expand="None"
80
+ 7. topic can not be "None"
81
+ 8. Use the provided memory to understand context and provide more relevant expansions and topics.
82
+ """
83
+
84
+ query: str = dspy.InputField(prefix = "Question: ")
85
+ memory: Memory = dspy.InputField(prefix = "Previous conversations: ", desc="This is a list of previous conversations.")
86
+ output: OutputFormat = dspy.OutputField(desc='''Expanded user question and tags are generated as output. Respond with a single JSON object. JSON Schema: {"properties": {"expand": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Expand"}, "topic": {"title": "Topic", "type": "string"}}, "required": ["expand", "topic"], "title": "OutputFormat", "type": "object"}''')
87
+
88
+ class OutputAgent(Agent):
89
+ """
90
+ Multi-output Agent Module. Inherited from Agent Module
91
+ """
92
+
93
+ def __init__(self, model: str | None = "llama3", client: str | None = "ollama", max_tokens: int | None = 8192) -> None:
94
+ super().__init__(
95
+ model = model,
96
+ client = client,
97
+ max_tokens = max_tokens
98
+ )
99
+
100
+ def __call__(self, query: str, memory: List[dict]) -> dspy.Prediction:
101
+ """
102
+ This function expands the user question and generates the tags for the user question.
103
+
104
+ Args:
105
+ query: str -> The current user query
106
+ memory: List[dict] -> List of previous conversations
107
+
108
+ Returns:
109
+ dspy.Prediction: Expanded question and topic
110
+ """
111
+
112
+ # Convert the memory list to the Memory model
113
+ conversations = [Conversation(role=m["role"], content=m["content"]) for m in memory]
114
+ memory_model = Memory(conversations=conversations)
115
+
116
+ # modules
117
+ outputGenerator = dspy.TypedPredictor(BaseSignature)
118
+
119
+ # infer
120
+ try:
121
+ output = outputGenerator(query=query, memory=memory_model)
122
+ return output
123
+ except Exception as e:
124
+ print("Retrying...", e)
125
+ return self.__call__(query=query, memory=memory)
126
+
127
+ # This function can be called from app.py to get the expanded question and topic
128
+ def get_expanded_query_and_topic(query: str, conversation_context: List[dict]):
129
+ agent = OutputAgent()
130
+ result = agent(query, conversation_context)
131
+ return result.output
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ dspy-ai
2
+ groq
3
+ streamlit