Spaces:
Running
Running
pseudotensor
commited on
Commit
•
e93eb3d
1
Parent(s):
6ef8710
open-strawberry
Browse files- app.py +350 -0
- cli.py +107 -0
- models.py +483 -0
- open_strawberry.py +491 -0
- requirements.txt +13 -0
- utils.py +57 -0
app.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import openai
|
4 |
+
|
5 |
+
import streamlit as st
|
6 |
+
import time
|
7 |
+
try:
|
8 |
+
from src.models import get_all_model_names
|
9 |
+
from src.open_strawberry import get_defaults, manage_conversation
|
10 |
+
except (ModuleNotFoundError, ImportError):
|
11 |
+
from models import get_all_model_names
|
12 |
+
from open_strawberry import get_defaults, manage_conversation
|
13 |
+
|
14 |
+
(model, system_prompt, initial_prompt, expected_answer,
|
15 |
+
next_prompts, num_turns, show_next, final_prompt,
|
16 |
+
temperature, max_tokens,
|
17 |
+
num_turns_final_mod,
|
18 |
+
show_cot,
|
19 |
+
verbose) = get_defaults()
|
20 |
+
|
21 |
+
st.title("Open Strawberry Conversation")
|
22 |
+
st.markdown("[Open Strawberry GitHub Repo](https://github.com/pseudotensor/open-strawberry)")
|
23 |
+
|
24 |
+
# Initialize session state
|
25 |
+
if "messages" not in st.session_state:
|
26 |
+
st.session_state.messages = []
|
27 |
+
if "turn_count" not in st.session_state:
|
28 |
+
st.session_state.turn_count = 0
|
29 |
+
if "input_key" not in st.session_state:
|
30 |
+
st.session_state.input_key = 0
|
31 |
+
if "conversation_started" not in st.session_state:
|
32 |
+
st.session_state.conversation_started = False
|
33 |
+
if "waiting_for_continue" not in st.session_state:
|
34 |
+
st.session_state.waiting_for_continue = False
|
35 |
+
if "generator" not in st.session_state:
|
36 |
+
st.session_state.generator = None # Store the generator in session state
|
37 |
+
if "prompt" not in st.session_state:
|
38 |
+
st.session_state.prompt = None # Store the prompt in session state
|
39 |
+
if "answer" not in st.session_state:
|
40 |
+
st.session_state.answer = None
|
41 |
+
if "system_prompt" not in st.session_state:
|
42 |
+
st.session_state.system_prompt = None
|
43 |
+
if "output_tokens" not in st.session_state:
|
44 |
+
st.session_state.output_tokens = 0
|
45 |
+
if "input_tokens" not in st.session_state:
|
46 |
+
st.session_state.input_tokens = 0
|
47 |
+
if "cache_creation_input_tokens" not in st.session_state:
|
48 |
+
st.session_state.cache_creation_input_tokens = 0
|
49 |
+
if "cache_read_input_tokens" not in st.session_state:
|
50 |
+
st.session_state.cache_read_input_tokens = 0
|
51 |
+
if "verbose" not in st.session_state:
|
52 |
+
st.session_state.verbose = verbose
|
53 |
+
if "max_tokens" not in st.session_state:
|
54 |
+
st.session_state.max_tokens = max_tokens
|
55 |
+
if "temperature" not in st.session_state:
|
56 |
+
st.session_state.temperature = temperature
|
57 |
+
if "next_prompts" not in st.session_state:
|
58 |
+
st.session_state.next_prompts = next_prompts
|
59 |
+
if "final_prompt" not in st.session_state:
|
60 |
+
st.session_state.final_prompt = final_prompt
|
61 |
+
|
62 |
+
|
63 |
+
# Function to display chat messages
|
64 |
+
def display_chat():
|
65 |
+
display_step = 1
|
66 |
+
for message in st.session_state.messages:
|
67 |
+
if message["role"] == "assistant":
|
68 |
+
if 'final' in message and message['final']:
|
69 |
+
display_final(message)
|
70 |
+
elif 'turn_title' in message and message['turn_title']:
|
71 |
+
display_turn_title(message, display_step=display_step)
|
72 |
+
display_step += 1
|
73 |
+
else:
|
74 |
+
with st.expander("Chain of Thoughts", expanded=st.session_state["show_cot"]):
|
75 |
+
assistant_container1 = st.chat_message("assistant")
|
76 |
+
with assistant_container1.container():
|
77 |
+
st.markdown(message["content"].replace('\n', ' \n'), unsafe_allow_html=True)
|
78 |
+
elif message["role"] == "user":
|
79 |
+
if not message["initial"] and not st.session_state.show_next:
|
80 |
+
continue
|
81 |
+
user_container1 = st.chat_message("user")
|
82 |
+
with user_container1:
|
83 |
+
st.markdown(message["content"].replace('\n', ' \n'), unsafe_allow_html=True)
|
84 |
+
|
85 |
+
|
86 |
+
def display_final(chunk1, can_rerun=False):
|
87 |
+
if 'final' in chunk1 and chunk1['final']:
|
88 |
+
if st.session_state.answer:
|
89 |
+
if st.session_state.answer.strip() in chunk1["content"]:
|
90 |
+
st.markdown(f'<h3 class="expander-title">🏆 Final Answer</h3>', unsafe_allow_html=True)
|
91 |
+
else:
|
92 |
+
st.markdown(f'Expected: **{st.session_state.answer.strip()}**', unsafe_allow_html=True)
|
93 |
+
st.markdown(f'<h3 class="expander-title">👎 Final Answer</h3>', unsafe_allow_html=True)
|
94 |
+
else:
|
95 |
+
st.markdown(f'<h3 class="expander-title">👌 Final Answer</h3>', unsafe_allow_html=True)
|
96 |
+
final = chunk1["content"].strip().replace('\n', ' \n')
|
97 |
+
if '\n' in final or '<br>' in final:
|
98 |
+
st.markdown(f'{final}', unsafe_allow_html=True)
|
99 |
+
else:
|
100 |
+
st.markdown(f'**{final}**', unsafe_allow_html=True)
|
101 |
+
if can_rerun:
|
102 |
+
# rerun to get token stats
|
103 |
+
st.rerun()
|
104 |
+
|
105 |
+
|
106 |
+
def display_turn_title(chunk1, display_step=None):
|
107 |
+
if display_step is None:
|
108 |
+
display_step = st.session_state.turn_count
|
109 |
+
name = "Completed Step"
|
110 |
+
else:
|
111 |
+
name = "Step"
|
112 |
+
if 'turn_title' in chunk1 and chunk1['turn_title']:
|
113 |
+
turn_title = chunk1["content"].strip().replace('\n', ' \n')
|
114 |
+
step_time = f' in time {str(int(chunk1["thinking_time"]))}s'
|
115 |
+
acum_time = f' in total {str(int(chunk1["total_thinking_time"]))}s'
|
116 |
+
st.markdown(f'**{name} {display_step}: {turn_title}{step_time}{acum_time}**', unsafe_allow_html=True)
|
117 |
+
|
118 |
+
|
119 |
+
if st.button("Start Conversation", disabled=st.session_state.conversation_started):
|
120 |
+
st.session_state.conversation_started = True
|
121 |
+
|
122 |
+
# Sidebar
|
123 |
+
st.sidebar.title("Controls")
|
124 |
+
|
125 |
+
on_hf_spaces = os.getenv("HF_SPACES", '0') == '1'
|
126 |
+
|
127 |
+
|
128 |
+
def save_env_vars(env_vars):
|
129 |
+
assert not on_hf_spaces, "Cannot save env vars in HF Spaces"
|
130 |
+
env_path = os.path.join(os.path.dirname(__file__), "..", ".env")
|
131 |
+
from dotenv import set_key
|
132 |
+
for key, value in env_vars.items():
|
133 |
+
set_key(env_path, key, value)
|
134 |
+
|
135 |
+
|
136 |
+
def get_dotenv_values():
|
137 |
+
if on_hf_spaces:
|
138 |
+
return st.session_state.secrets
|
139 |
+
else:
|
140 |
+
from dotenv import dotenv_values
|
141 |
+
return dotenv_values(os.path.join(os.path.dirname(__file__), "..", ".env"))
|
142 |
+
|
143 |
+
|
144 |
+
if 'secrets' not in st.session_state:
|
145 |
+
if on_hf_spaces:
|
146 |
+
# allow user to enter
|
147 |
+
st.session_state.secrets = dict(OPENAI_API_KEY='',
|
148 |
+
OPENAI_BASE_URL='https://api.openai.com/v1',
|
149 |
+
OPENAI_MODEL_NAME='',
|
150 |
+
# OLLAMA_OPENAI_API_KEY='',
|
151 |
+
# OLLAMA_OPENAI_BASE_URL='http://localhost:11434/v1/',
|
152 |
+
# OLLAMA_OPENAI_MODEL_NAME='',
|
153 |
+
# AZURE_OPENAI_API_KEY='',
|
154 |
+
# AZURE_OPENAI_API_VERSION='',
|
155 |
+
# AZURE_OPENAI_ENDPOINT='',
|
156 |
+
# AZURE_OPENAI_DEPLOYMENT='',
|
157 |
+
# AZURE_OPENAI_MODEL_NAME='',
|
158 |
+
GEMINI_API_KEY='',
|
159 |
+
# MISTRAL_API_KEY='',
|
160 |
+
GROQ_API_KEY='',
|
161 |
+
ANTHROPIC_API_KEY='',
|
162 |
+
)
|
163 |
+
|
164 |
+
else:
|
165 |
+
st.session_state.secrets = {}
|
166 |
+
|
167 |
+
|
168 |
+
def update_model_selection():
|
169 |
+
visible_models1 = get_all_model_names(st.session_state.secrets, on_hf_spaces)
|
170 |
+
if visible_models1 and "model_name" in st.session_state:
|
171 |
+
if st.session_state.model_name not in visible_models1:
|
172 |
+
st.session_state.model_name = visible_models1[0]
|
173 |
+
|
174 |
+
|
175 |
+
# Replace the existing model selection code with this
|
176 |
+
if 'model_name' not in st.session_state or not st.session_state.model_name:
|
177 |
+
update_model_selection()
|
178 |
+
|
179 |
+
# Model selection
|
180 |
+
visible_models = get_all_model_names(st.session_state.secrets, on_hf_spaces)
|
181 |
+
st.sidebar.selectbox("Select Model", visible_models, key="model_name",
|
182 |
+
disabled=st.session_state.conversation_started)
|
183 |
+
st.sidebar.checkbox("Show Next", value=show_next, key="show_next", disabled=st.session_state.conversation_started)
|
184 |
+
st.sidebar.number_input("Num Turns to Check if Final Answer", value=num_turns_final_mod, key="num_turns_final_mod",
|
185 |
+
disabled=st.session_state.conversation_started)
|
186 |
+
st.sidebar.number_input("Num Turns per User Click of Continue", value=num_turns, key="num_turns",
|
187 |
+
disabled=st.session_state.conversation_started)
|
188 |
+
st.sidebar.checkbox("Show Chain of Thoughts Details", value=show_cot, key="show_cot",
|
189 |
+
disabled=st.session_state.conversation_started)
|
190 |
+
|
191 |
+
# Reset conversation button
|
192 |
+
reset_clicked = st.sidebar.button("Reset Conversation")
|
193 |
+
with st.sidebar.expander("Edit in-memory session secrets" if on_hf_spaces else "Edit .env", expanded=on_hf_spaces):
|
194 |
+
dotenv_dict = get_dotenv_values()
|
195 |
+
new_env = {}
|
196 |
+
for k, v in dotenv_dict.items():
|
197 |
+
new_env[k] = st.text_input(k, value=v, key=k, disabled=st.session_state.conversation_started, type="password")
|
198 |
+
st.session_state.secrets[k] = new_env[k]
|
199 |
+
save_secrets_clicked = st.button("Save dotenv" if not on_hf_spaces else "Save secrets to memory")
|
200 |
+
|
201 |
+
if save_secrets_clicked:
|
202 |
+
if on_hf_spaces:
|
203 |
+
st.success("secrets temporarily stored to your session memory only")
|
204 |
+
else:
|
205 |
+
save_env_vars(st.session_state.user_secrets)
|
206 |
+
st.success("dotenv saved to .env file")
|
207 |
+
|
208 |
+
if reset_clicked:
|
209 |
+
st.session_state.messages = []
|
210 |
+
st.session_state.turn_count = 0
|
211 |
+
st.sidebar.write(f"Turn count: {st.session_state.turn_count}")
|
212 |
+
st.session_state.input_key += 1
|
213 |
+
st.session_state.conversation_started = False
|
214 |
+
st.session_state.generator = None # Reset the generator
|
215 |
+
reset_clicked = False
|
216 |
+
st.session_state.output_tokens = 0
|
217 |
+
st.session_state.input_tokens = 0
|
218 |
+
st.session_state.cache_creation_input_tokens = 0
|
219 |
+
st.session_state.cache_read_input_tokens = 0
|
220 |
+
st.rerun()
|
221 |
+
|
222 |
+
st.session_state.waiting_for_continue = False
|
223 |
+
|
224 |
+
# Display debug information
|
225 |
+
st.sidebar.write(f"Turn count: {st.session_state.turn_count}")
|
226 |
+
num_messages = len([x for x in st.session_state.messages if x.get('role', '') == 'assistant'])
|
227 |
+
st.sidebar.write(f"Number of AI messages: {num_messages}")
|
228 |
+
st.sidebar.write(f"Conversation started: {st.session_state.conversation_started}")
|
229 |
+
st.sidebar.write(f"Output tokens: {st.session_state.output_tokens}")
|
230 |
+
st.sidebar.write(f"Input tokens: {st.session_state.input_tokens}")
|
231 |
+
st.sidebar.write(f"Cache creation input tokens: {st.session_state.cache_creation_input_tokens}")
|
232 |
+
st.sidebar.write(f"Cache read input tokens: {st.session_state.cache_read_input_tokens}")
|
233 |
+
|
234 |
+
# Handle user input
|
235 |
+
if not st.session_state.conversation_started:
|
236 |
+
prompt = st.text_area("What would you like to ask?", value=initial_prompt,
|
237 |
+
key=f"input_{st.session_state.input_key}", height=500)
|
238 |
+
st.session_state.prompt = prompt
|
239 |
+
answer = st.text_area("Expected answer (Empty if do not know)", value=expected_answer,
|
240 |
+
key=f"answer_{st.session_state.input_key}", height=100)
|
241 |
+
st.session_state.answer = answer
|
242 |
+
system_prompt = st.text_area("System Prompt", value=system_prompt,
|
243 |
+
key=f"system_prompt_{st.session_state.input_key}", height=200)
|
244 |
+
st.session_state.system_prompt = system_prompt
|
245 |
+
else:
|
246 |
+
st.session_state.conversation_started = True
|
247 |
+
st.session_state.input_key += 1
|
248 |
+
|
249 |
+
# Display chat history
|
250 |
+
chat_container = st.container()
|
251 |
+
with chat_container:
|
252 |
+
display_chat()
|
253 |
+
|
254 |
+
# Process conversation
|
255 |
+
current_assistant_message = ''
|
256 |
+
assistant_placeholder = None
|
257 |
+
|
258 |
+
try:
|
259 |
+
while True:
|
260 |
+
if st.session_state.waiting_for_continue:
|
261 |
+
time.sleep(0.1) # Short sleep to prevent excessive CPU usage
|
262 |
+
continue
|
263 |
+
if not st.session_state.conversation_started:
|
264 |
+
time.sleep(0.1)
|
265 |
+
continue
|
266 |
+
elif st.session_state.generator is None:
|
267 |
+
st.session_state.generator = manage_conversation(
|
268 |
+
model=st.session_state["model_name"],
|
269 |
+
system=st.session_state.system_prompt,
|
270 |
+
initial_prompt=st.session_state.prompt,
|
271 |
+
next_prompts=st.session_state.next_prompts,
|
272 |
+
final_prompt=st.session_state.final_prompt,
|
273 |
+
num_turns_final_mod=st.session_state.num_turns_final_mod,
|
274 |
+
num_turns=st.session_state.num_turns,
|
275 |
+
temperature=st.session_state.temperature,
|
276 |
+
max_tokens=st.session_state.max_tokens,
|
277 |
+
verbose=st.session_state.verbose,
|
278 |
+
)
|
279 |
+
chunk = next(st.session_state.generator)
|
280 |
+
if chunk["role"] == "assistant":
|
281 |
+
if not chunk.get('final', False) and not chunk.get('turn_title', False):
|
282 |
+
current_assistant_message += chunk["content"]
|
283 |
+
if assistant_placeholder is None:
|
284 |
+
assistant_placeholder = st.empty() # Placeholder for assistant's message
|
285 |
+
|
286 |
+
# Update the assistant container with the progressively streaming message
|
287 |
+
with assistant_placeholder.container():
|
288 |
+
# Update in the same chat message
|
289 |
+
with st.expander("Chain of Thoughts", expanded=st.session_state["show_cot"]):
|
290 |
+
st.chat_message("assistant").markdown(current_assistant_message, unsafe_allow_html=True)
|
291 |
+
if 'turn_title' in chunk and chunk['turn_title']:
|
292 |
+
st.session_state.messages.append(
|
293 |
+
{"role": "assistant", "content": chunk['content'], 'turn_title': True,
|
294 |
+
'thinking_time': chunk['thinking_time'],
|
295 |
+
'total_thinking_time': chunk['total_thinking_time']})
|
296 |
+
display_turn_title(chunk)
|
297 |
+
if 'final' in chunk and chunk['final']:
|
298 |
+
# user role would normally do this, but on final step needs to be here
|
299 |
+
st.session_state.messages.append(
|
300 |
+
{"role": "assistant", "content": current_assistant_message, 'final': False})
|
301 |
+
# last message, so won't reach user turn, so need to store final assistant message from parsing
|
302 |
+
st.session_state.messages.append(
|
303 |
+
{"role": "assistant", "content": chunk['content'], 'final': True})
|
304 |
+
display_final(chunk, can_rerun=True)
|
305 |
+
|
306 |
+
elif chunk["role"] == "user":
|
307 |
+
if current_assistant_message:
|
308 |
+
st.session_state.messages.append(
|
309 |
+
{"role": "assistant", "content": current_assistant_message, 'final': chunk.get('final', False)})
|
310 |
+
# Reset assistant message when user provides input
|
311 |
+
# Display user message
|
312 |
+
if not chunk["initial"] and not st.session_state.show_next:
|
313 |
+
pass
|
314 |
+
else:
|
315 |
+
user_container = st.chat_message("user")
|
316 |
+
with user_container:
|
317 |
+
st.markdown(chunk["content"].replace('\n', ' \n'), unsafe_allow_html=True)
|
318 |
+
st.session_state.messages.append({"role": "user", "content": chunk["content"], 'initial': chunk["initial"]})
|
319 |
+
|
320 |
+
st.session_state.turn_count += 1
|
321 |
+
if current_assistant_message:
|
322 |
+
assistant_placeholder = st.empty() # Reset placeholder
|
323 |
+
current_assistant_message = ""
|
324 |
+
|
325 |
+
elif chunk["role"] == "action":
|
326 |
+
if chunk["content"] in ["continue?"]:
|
327 |
+
# Continue conversation button
|
328 |
+
continue_clicked = st.button("Continue Conversation")
|
329 |
+
st.session_state.waiting_for_continue = True
|
330 |
+
st.session_state.turn_count += 1
|
331 |
+
if current_assistant_message:
|
332 |
+
st.session_state.messages.append({"role": "assistant", "content": current_assistant_message})
|
333 |
+
assistant_placeholder = st.empty() # Reset placeholder
|
334 |
+
current_assistant_message = ""
|
335 |
+
elif chunk["content"] == "end":
|
336 |
+
break
|
337 |
+
elif chunk["role"] == "usage":
|
338 |
+
st.session_state.output_tokens += chunk["content"]["output_tokens"] if "output_tokens" in chunk[
|
339 |
+
"content"] else 0
|
340 |
+
st.session_state.input_tokens += chunk["content"]["input_tokens"] if "input_tokens" in chunk[
|
341 |
+
"content"] else 0
|
342 |
+
st.session_state.cache_creation_input_tokens += chunk["content"][
|
343 |
+
"cache_creation_input_tokens"] if "cache_creation_input_tokens" in chunk["content"] else 0
|
344 |
+
st.session_state.cache_read_input_tokens += chunk["content"][
|
345 |
+
"cache_read_input_tokens"] if "cache_read_input_tokens" in chunk["content"] else 0
|
346 |
+
|
347 |
+
time.sleep(0.001) # Small delay to prevent excessive updates
|
348 |
+
|
349 |
+
except StopIteration:
|
350 |
+
pass
|
cli.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import time
|
3 |
+
|
4 |
+
from src.open_strawberry import get_defaults, manage_conversation
|
5 |
+
|
6 |
+
|
7 |
+
def parse_arguments(model, system_prompt, next_prompts, num_turns, show_next, final_prompt,
|
8 |
+
num_turns_final_mod, show_cot, verbose):
|
9 |
+
parser = argparse.ArgumentParser(description="Open Strawberry Conversation Manager")
|
10 |
+
parser.add_argument("--show_next", action="store_true", default=show_next, help="Show all messages")
|
11 |
+
parser.add_argument("--verbose", action="store_true", default=verbose, help="Show usage information")
|
12 |
+
parser.add_argument("--system_prompt", type=str, default=system_prompt, help="Custom system prompt")
|
13 |
+
parser.add_argument("--num_turns_final_mod", type=int, default=num_turns_final_mod,
|
14 |
+
help="Number of turns before final prompt")
|
15 |
+
parser.add_argument("--num_turns", type=int, default=num_turns,
|
16 |
+
help="Number of turns before pausing for continuation")
|
17 |
+
parser.add_argument("--model", type=str, default=model, help="Model to use for conversation")
|
18 |
+
parser.add_argument("--initial_prompt", type=str, default='', help="Initial prompt. If empty, then ask user.")
|
19 |
+
parser.add_argument("--expected_answer", type=str, default='', help="Expected answer. If empty, then ignore.")
|
20 |
+
parser.add_argument("--next_prompts", type=str, nargs="+", default=next_prompts, help="Next prompts")
|
21 |
+
parser.add_argument("--final_prompt", type=str, default=final_prompt, help="Final prompt")
|
22 |
+
parser.add_argument("--temperature", type=float, default=0.3, help="Temperature for the model")
|
23 |
+
parser.add_argument("--max_tokens", type=int, default=1024, help="Maximum tokens for the model")
|
24 |
+
parser.add_argument("--seed", type=int, default=0, help="Random seed, 0 means random seed")
|
25 |
+
parser.add_argument("--show_cot", type=bool, default=show_cot, help="Whether to show detailed Chain of Thoughts")
|
26 |
+
|
27 |
+
return parser.parse_args()
|
28 |
+
|
29 |
+
|
30 |
+
def go_cli():
|
31 |
+
(model, system_prompt, initial_prompt, expected_answer,
|
32 |
+
next_prompts, num_turns, show_next, final_prompt,
|
33 |
+
temperature, max_tokens, num_turns_final_mod,
|
34 |
+
show_cot, verbose) = get_defaults()
|
35 |
+
args = parse_arguments(model, system_prompt, next_prompts, num_turns, show_next, final_prompt,
|
36 |
+
num_turns_final_mod, show_cot, verbose)
|
37 |
+
|
38 |
+
if args.initial_prompt == '':
|
39 |
+
initial_prompt_query = input("Enter the initial prompt (hitting enter will use default initial_prompt)\n\n")
|
40 |
+
if initial_prompt_query not in ['', '\n', '\r\n']:
|
41 |
+
initial_prompt_chosen = initial_prompt_query
|
42 |
+
else:
|
43 |
+
initial_prompt_chosen = initial_prompt
|
44 |
+
else:
|
45 |
+
initial_prompt_chosen = args.initial_prompt
|
46 |
+
|
47 |
+
generator = manage_conversation(model=args.model,
|
48 |
+
system=args.system_prompt,
|
49 |
+
initial_prompt=initial_prompt_chosen,
|
50 |
+
next_prompts=args.next_prompts,
|
51 |
+
final_prompt=args.final_prompt,
|
52 |
+
num_turns_final_mod=args.num_turns_final_mod,
|
53 |
+
num_turns=args.num_turns,
|
54 |
+
temperature=args.temperature,
|
55 |
+
max_tokens=args.max_tokens,
|
56 |
+
seed=args.seed,
|
57 |
+
cli_mode=True)
|
58 |
+
response = ''
|
59 |
+
conversation_history = []
|
60 |
+
|
61 |
+
try:
|
62 |
+
step = 1
|
63 |
+
while True:
|
64 |
+
chunk = next(generator)
|
65 |
+
if 'role' in chunk and chunk['role'] == 'assistant':
|
66 |
+
response += chunk['content']
|
67 |
+
|
68 |
+
if 'turn_title' in chunk and chunk['turn_title']:
|
69 |
+
step_time = f' in time {str(int(chunk["thinking_time"]))}s'
|
70 |
+
acum_time = f' in total {str(int(chunk["total_thinking_time"]))}s'
|
71 |
+
extra = '\n\n' if show_cot else ''
|
72 |
+
extra2 = '**' if show_cot else ''
|
73 |
+
extra3 = ' ' if show_cot else ''
|
74 |
+
print(
|
75 |
+
f'{extra}{extra2}{extra3}Completed Step {step}: {chunk["content"]}{step_time}{acum_time}{extra3}{extra2}{extra}')
|
76 |
+
step += 1
|
77 |
+
elif 'final' in chunk and chunk['final']:
|
78 |
+
if '\n' in chunk['content'] or '\r' in chunk['content']:
|
79 |
+
print(f'\n\nFinal Answer:\n\n {chunk["content"]}')
|
80 |
+
else:
|
81 |
+
print('\n\nFinal Answer:\n\n**', chunk['content'], '**\n\n')
|
82 |
+
elif show_cot:
|
83 |
+
print(chunk['content'], end='')
|
84 |
+
if 'chat_history' in chunk:
|
85 |
+
conversation_history = chunk['chat_history']
|
86 |
+
elif 'role' in chunk and chunk['role'] == 'user':
|
87 |
+
if not chunk['initial'] and not show_next:
|
88 |
+
if show_cot:
|
89 |
+
print('\n\n')
|
90 |
+
continue
|
91 |
+
print('\n', end='') # finish assistant
|
92 |
+
print('\nUser: ', chunk['content'], end='\n\n')
|
93 |
+
print('\nAssistant:\n\n ')
|
94 |
+
time.sleep(0.001)
|
95 |
+
except StopIteration as e:
|
96 |
+
pass
|
97 |
+
|
98 |
+
if verbose:
|
99 |
+
print("Conversation history:", conversation_history)
|
100 |
+
|
101 |
+
if expected_answer and expected_answer in conversation_history[-1]['content']:
|
102 |
+
print("\n\nGot Expected answer!")
|
103 |
+
|
104 |
+
if not show_cot:
|
105 |
+
print("**FULL RESPONSE:**\n\n")
|
106 |
+
print(response)
|
107 |
+
return response
|
models.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import datetime
|
3 |
+
import os
|
4 |
+
from typing import List, Dict, Generator
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
|
7 |
+
from tenacity import (
|
8 |
+
retry,
|
9 |
+
stop_after_attempt,
|
10 |
+
wait_random_exponential,
|
11 |
+
) # for exponential backoff
|
12 |
+
|
13 |
+
# Load environment variables from .env file
|
14 |
+
load_dotenv()
|
15 |
+
|
16 |
+
|
17 |
+
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(3))
|
18 |
+
def anthropic_completion_with_backoff(client, *args, **kwargs):
|
19 |
+
return client.beta.prompt_caching.messages.create(*args, **kwargs)
|
20 |
+
|
21 |
+
|
22 |
+
def get_anthropic(model: str,
|
23 |
+
prompt: str,
|
24 |
+
temperature: float = 0,
|
25 |
+
max_tokens: int = 4096,
|
26 |
+
system: str = '',
|
27 |
+
chat_history: List[Dict] = None,
|
28 |
+
verbose=False) -> \
|
29 |
+
Generator[dict, None, None]:
|
30 |
+
model = model.replace('anthropic:', '')
|
31 |
+
|
32 |
+
# https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
|
33 |
+
import anthropic
|
34 |
+
|
35 |
+
clawd_key = os.getenv('ANTHROPIC_API_KEY')
|
36 |
+
clawd_client = anthropic.Anthropic(api_key=clawd_key) if clawd_key else None
|
37 |
+
|
38 |
+
if chat_history is None:
|
39 |
+
chat_history = []
|
40 |
+
|
41 |
+
messages = []
|
42 |
+
|
43 |
+
# Add conversation history, removing cache_control from all but the last two user messages
|
44 |
+
for i, message in enumerate(chat_history):
|
45 |
+
if message["role"] == "user":
|
46 |
+
if i >= len(chat_history) - 3: # Last two user messages
|
47 |
+
messages.append(message)
|
48 |
+
else:
|
49 |
+
messages.append({
|
50 |
+
"role": "user",
|
51 |
+
"content": [{"type": "text", "text": message["content"][0]["text"]}]
|
52 |
+
})
|
53 |
+
else:
|
54 |
+
messages.append(message)
|
55 |
+
|
56 |
+
# Add the new user message
|
57 |
+
messages.append({
|
58 |
+
"role": "user",
|
59 |
+
"content": [
|
60 |
+
{
|
61 |
+
"type": "text",
|
62 |
+
"text": prompt,
|
63 |
+
"cache_control": {"type": "ephemeral"}
|
64 |
+
}
|
65 |
+
]
|
66 |
+
})
|
67 |
+
|
68 |
+
response = anthropic_completion_with_backoff(clawd_client,
|
69 |
+
model=model,
|
70 |
+
max_tokens=max_tokens,
|
71 |
+
temperature=temperature,
|
72 |
+
system=system,
|
73 |
+
messages=messages,
|
74 |
+
stream=True
|
75 |
+
)
|
76 |
+
|
77 |
+
output_tokens = 0
|
78 |
+
input_tokens = 0
|
79 |
+
cache_creation_input_tokens = 0
|
80 |
+
cache_read_input_tokens = 0
|
81 |
+
for chunk in response:
|
82 |
+
if chunk.type == "content_block_start":
|
83 |
+
# This is where we might find usage info in the future
|
84 |
+
pass
|
85 |
+
elif chunk.type == "content_block_delta":
|
86 |
+
yield dict(text=chunk.delta.text)
|
87 |
+
elif chunk.type == "message_delta":
|
88 |
+
output_tokens = dict(chunk.usage).get('output_tokens', 0)
|
89 |
+
elif chunk.type == "message_start":
|
90 |
+
usage = chunk.message.usage
|
91 |
+
input_tokens = dict(usage).get('input_tokens', 0)
|
92 |
+
cache_creation_input_tokens = dict(usage).get('cache_creation_input_tokens', 0)
|
93 |
+
cache_read_input_tokens = dict(usage).get('cache_read_input_tokens', 0)
|
94 |
+
else:
|
95 |
+
if verbose:
|
96 |
+
print("Unknown chunk type:", chunk.type)
|
97 |
+
print("Chunk:", chunk)
|
98 |
+
|
99 |
+
if verbose:
|
100 |
+
# After streaming is complete, print the usage information
|
101 |
+
print(f"Output tokens: {output_tokens}")
|
102 |
+
print(f"Input tokens: {input_tokens}")
|
103 |
+
print(f"Cache creation input tokens: {cache_creation_input_tokens}")
|
104 |
+
print(f"Cache read input tokens: {cache_read_input_tokens}")
|
105 |
+
yield dict(output_tokens=output_tokens, input_tokens=input_tokens,
|
106 |
+
cache_creation_input_tokens=cache_creation_input_tokens,
|
107 |
+
cache_read_input_tokens=cache_read_input_tokens)
|
108 |
+
|
109 |
+
|
110 |
+
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(3))
|
111 |
+
def openai_completion_with_backoff(client, *args, **kwargs):
|
112 |
+
return client.chat.completions.create(*args, **kwargs)
|
113 |
+
|
114 |
+
|
115 |
+
def get_openai(model: str,
|
116 |
+
prompt: str,
|
117 |
+
temperature: float = 0,
|
118 |
+
max_tokens: int = 4096,
|
119 |
+
system: str = '',
|
120 |
+
chat_history: List[Dict] = None,
|
121 |
+
verbose=False) -> Generator[dict, None, None]:
|
122 |
+
anthropic_models, openai_models, google_models, groq_models, azure_models, ollama = get_model_names()
|
123 |
+
if model in ollama:
|
124 |
+
model = model.replace('ollama:', '')
|
125 |
+
openai_key = os.getenv('OLLAMA_OPENAI_API_KEY')
|
126 |
+
openai_base_url = os.getenv('OLLAMA_OPENAI_BASE_URL', 'http://localhost:11434/v1/')
|
127 |
+
else:
|
128 |
+
model = model.replace('openai:', '')
|
129 |
+
openai_key = os.getenv('OPENAI_API_KEY')
|
130 |
+
openai_base_url = os.getenv('OPENAI_BASE_URL', 'https://api.openai.com/v1')
|
131 |
+
|
132 |
+
from openai import OpenAI
|
133 |
+
|
134 |
+
openai_client = OpenAI(api_key=openai_key, base_url=openai_base_url) if openai_key else None
|
135 |
+
|
136 |
+
if chat_history is None:
|
137 |
+
chat_history = []
|
138 |
+
chat_history_copy = chat_history.copy()
|
139 |
+
for mi, message in enumerate(chat_history_copy):
|
140 |
+
if isinstance(message["content"], list):
|
141 |
+
chat_history_copy[mi]["content"] = message["content"][0]["text"]
|
142 |
+
chat_history = chat_history_copy
|
143 |
+
|
144 |
+
messages = [{"role": "system", "content": system}] + chat_history + [{"role": "user", "content": prompt}]
|
145 |
+
|
146 |
+
response = openai_completion_with_backoff(openai_client,
|
147 |
+
model=model,
|
148 |
+
messages=messages,
|
149 |
+
temperature=temperature,
|
150 |
+
max_tokens=max_tokens,
|
151 |
+
stream=True,
|
152 |
+
)
|
153 |
+
|
154 |
+
output_tokens = 0
|
155 |
+
input_tokens = 0
|
156 |
+
for chunk in response:
|
157 |
+
if chunk.choices[0].delta.content:
|
158 |
+
yield dict(text=chunk.choices[0].delta.content)
|
159 |
+
if chunk.usage:
|
160 |
+
output_tokens = chunk.usage.completion_tokens
|
161 |
+
input_tokens = chunk.usage.prompt_tokens
|
162 |
+
|
163 |
+
if verbose:
|
164 |
+
print(f"Output tokens: {output_tokens}")
|
165 |
+
print(f"Input tokens: {input_tokens}")
|
166 |
+
yield dict(output_tokens=output_tokens, input_tokens=input_tokens)
|
167 |
+
|
168 |
+
|
169 |
+
def openai_messages_to_gemini_history(messages):
|
170 |
+
"""Converts OpenAI messages to Gemini history format.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
messages: A list of OpenAI messages, each with "role" and "content" keys.
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
A list of dictionaries representing the chat history for Gemini.
|
177 |
+
"""
|
178 |
+
history = []
|
179 |
+
for message in messages:
|
180 |
+
if isinstance(message["content"], list):
|
181 |
+
message["content"] = message["content"][0]["text"]
|
182 |
+
if message["role"] == "user":
|
183 |
+
history.append({"role": "user", "parts": [{"text": message["content"]}]})
|
184 |
+
elif message["role"] == "assistant":
|
185 |
+
history.append({"role": "model", "parts": [{"text": message["content"]}]})
|
186 |
+
# Optionally handle system messages if needed
|
187 |
+
# elif message["role"] == "system":
|
188 |
+
# history.append({"role": "system", "parts": [{"text": message["content"]}]})
|
189 |
+
|
190 |
+
return history
|
191 |
+
|
192 |
+
|
193 |
+
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(3))
|
194 |
+
def gemini_send_message_with_backoff(chat, prompt, stream=True):
|
195 |
+
return chat.send_message(prompt, stream=stream)
|
196 |
+
|
197 |
+
|
198 |
+
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(3))
|
199 |
+
def gemini_generate_content_with_backoff(model, prompt, stream=True):
|
200 |
+
return model.generate_content(prompt, stream=stream)
|
201 |
+
|
202 |
+
|
203 |
+
def get_google(model: str,
|
204 |
+
prompt: str,
|
205 |
+
temperature: float = 0,
|
206 |
+
max_tokens: int = 4096,
|
207 |
+
system: str = '',
|
208 |
+
chat_history: List[Dict] = None,
|
209 |
+
verbose=False) -> Generator[dict, None, None]:
|
210 |
+
model = model.replace('google:', '').replace('gemini:', '')
|
211 |
+
|
212 |
+
import google.generativeai as genai
|
213 |
+
|
214 |
+
gemini_key = os.getenv("GEMINI_API_KEY")
|
215 |
+
genai.configure(api_key=gemini_key)
|
216 |
+
# Create the model
|
217 |
+
generation_config = {
|
218 |
+
"temperature": temperature,
|
219 |
+
"top_p": 0.95,
|
220 |
+
"top_k": 64,
|
221 |
+
"max_output_tokens": max_tokens,
|
222 |
+
"response_mime_type": "text/plain",
|
223 |
+
}
|
224 |
+
|
225 |
+
if chat_history is None:
|
226 |
+
chat_history = []
|
227 |
+
|
228 |
+
chat_history = chat_history.copy()
|
229 |
+
chat_history = openai_messages_to_gemini_history(chat_history)
|
230 |
+
|
231 |
+
# NOTE: assume want own control. Too many false positives by Google.
|
232 |
+
from google.generativeai.types import HarmCategory
|
233 |
+
from google.generativeai.types import HarmBlockThreshold
|
234 |
+
safety_settings = {
|
235 |
+
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
|
236 |
+
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
|
237 |
+
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
|
238 |
+
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
|
239 |
+
}
|
240 |
+
|
241 |
+
cache = None
|
242 |
+
# disable cache for now until work into things well
|
243 |
+
use_cache = False
|
244 |
+
if use_cache and model == 'gemini-1.5-pro':
|
245 |
+
from google.generativeai import caching
|
246 |
+
# Estimate token count (this is a rough estimate, you may need a more accurate method)
|
247 |
+
estimated_tokens = len(prompt.split()) + sum(len(msg['content'].split()) for msg in chat_history)
|
248 |
+
|
249 |
+
if estimated_tokens > 32000:
|
250 |
+
cache = caching.CachedContent.create(
|
251 |
+
model=model,
|
252 |
+
display_name=f'cache_{datetime.datetime.now().isoformat()}',
|
253 |
+
system_instruction=system,
|
254 |
+
contents=[prompt] + [msg['content'] for msg in chat_history],
|
255 |
+
ttl=datetime.timedelta(minutes=5), # Set an appropriate TTL. Short for now for cost savings.
|
256 |
+
)
|
257 |
+
gemini_model = genai.GenerativeModel.from_cached_content(cached_content=cache)
|
258 |
+
else:
|
259 |
+
gemini_model = genai.GenerativeModel(model_name=model,
|
260 |
+
generation_config=generation_config,
|
261 |
+
safety_settings=safety_settings)
|
262 |
+
else:
|
263 |
+
gemini_model = genai.GenerativeModel(model_name=model,
|
264 |
+
generation_config=generation_config,
|
265 |
+
safety_settings=safety_settings)
|
266 |
+
|
267 |
+
if cache:
|
268 |
+
response = gemini_generate_content_with_backoff(gemini_model, prompt, stream=True)
|
269 |
+
else:
|
270 |
+
chat = gemini_model.start_chat(history=chat_history)
|
271 |
+
response = gemini_send_message_with_backoff(chat, prompt, stream=True)
|
272 |
+
|
273 |
+
output_tokens = 0
|
274 |
+
input_tokens = 0
|
275 |
+
cache_read_input_tokens = 0
|
276 |
+
cache_creation_input_tokens = 0
|
277 |
+
|
278 |
+
for chunk in response:
|
279 |
+
if chunk.text:
|
280 |
+
yield dict(text=chunk.text)
|
281 |
+
if chunk.usage_metadata:
|
282 |
+
output_tokens = chunk.usage_metadata.candidates_token_count
|
283 |
+
input_tokens = chunk.usage_metadata.prompt_token_count
|
284 |
+
cache_read_input_tokens = chunk.usage_metadata.cached_content_token_count
|
285 |
+
cache_creation_input_tokens = 0 # This might need to be updated if available in the API
|
286 |
+
|
287 |
+
if verbose:
|
288 |
+
print(f"Output tokens: {output_tokens}")
|
289 |
+
print(f"Input tokens: {input_tokens}")
|
290 |
+
print(f"Cached tokens: {cache_read_input_tokens}")
|
291 |
+
|
292 |
+
yield dict(output_tokens=output_tokens, input_tokens=input_tokens,
|
293 |
+
cache_read_input_tokens=cache_read_input_tokens,
|
294 |
+
cache_creation_input_tokens=cache_creation_input_tokens)
|
295 |
+
|
296 |
+
|
297 |
+
def delete_cache(cache):
|
298 |
+
if cache:
|
299 |
+
cache.delete()
|
300 |
+
print(f"Cache {cache.display_name} deleted.")
|
301 |
+
else:
|
302 |
+
print("No cache to delete.")
|
303 |
+
|
304 |
+
|
305 |
+
def get_groq(model: str,
|
306 |
+
prompt: str,
|
307 |
+
temperature: float = 0,
|
308 |
+
max_tokens: int = 4096,
|
309 |
+
system: str = '',
|
310 |
+
chat_history: List[Dict] = None,
|
311 |
+
verbose=False) -> Generator[dict, None, None]:
|
312 |
+
model = model.replace('groq:', '')
|
313 |
+
|
314 |
+
from groq import Groq
|
315 |
+
|
316 |
+
groq_key = os.getenv("GROQ_API_KEY")
|
317 |
+
client = Groq(api_key=groq_key)
|
318 |
+
|
319 |
+
if chat_history is None:
|
320 |
+
chat_history = []
|
321 |
+
|
322 |
+
chat_history = chat_history.copy()
|
323 |
+
|
324 |
+
messages = [{"role": "system", "content": system}] + chat_history + [{"role": "user", "content": prompt}]
|
325 |
+
|
326 |
+
stream = openai_completion_with_backoff(client,
|
327 |
+
messages=messages,
|
328 |
+
model=model,
|
329 |
+
temperature=temperature,
|
330 |
+
max_tokens=max_tokens,
|
331 |
+
stream=True,
|
332 |
+
)
|
333 |
+
|
334 |
+
output_tokens = 0
|
335 |
+
input_tokens = 0
|
336 |
+
for chunk in stream:
|
337 |
+
if chunk.choices[0].delta.content:
|
338 |
+
yield dict(text=chunk.choices[0].delta.content)
|
339 |
+
if chunk.usage:
|
340 |
+
output_tokens = chunk.usage.completion_tokens
|
341 |
+
input_tokens = chunk.usage.prompt_tokens
|
342 |
+
|
343 |
+
if verbose:
|
344 |
+
print(f"Output tokens: {output_tokens}")
|
345 |
+
print(f"Input tokens: {input_tokens}")
|
346 |
+
yield dict(output_tokens=output_tokens, input_tokens=input_tokens)
|
347 |
+
|
348 |
+
|
349 |
+
def get_openai_azure(model: str,
|
350 |
+
prompt: str,
|
351 |
+
temperature: float = 0,
|
352 |
+
max_tokens: int = 4096,
|
353 |
+
system: str = '',
|
354 |
+
chat_history: List[Dict] = None,
|
355 |
+
verbose=False) -> Generator[dict, None, None]:
|
356 |
+
model = model.replace('azure:', '').replace('openai_azure:', '')
|
357 |
+
|
358 |
+
from openai import AzureOpenAI
|
359 |
+
|
360 |
+
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") # e.g. https://project.openai.azure.com
|
361 |
+
azure_key = os.getenv("AZURE_OPENAI_API_KEY")
|
362 |
+
azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") # i.e. deployment name with some models deployed
|
363 |
+
azure_api_version = os.getenv('AZURE_OPENAI_API_VERSION', '2024-07-01-preview')
|
364 |
+
assert azure_endpoint is not None, "Azure OpenAI endpoint not set"
|
365 |
+
assert azure_key is not None, "Azure OpenAI API key not set"
|
366 |
+
assert azure_deployment is not None, "Azure OpenAI deployment not set"
|
367 |
+
|
368 |
+
client = AzureOpenAI(
|
369 |
+
azure_endpoint=azure_endpoint,
|
370 |
+
api_key=azure_key,
|
371 |
+
api_version=azure_api_version,
|
372 |
+
azure_deployment=azure_deployment,
|
373 |
+
)
|
374 |
+
|
375 |
+
if chat_history is None:
|
376 |
+
chat_history = []
|
377 |
+
|
378 |
+
messages = [{"role": "system", "content": system}] + chat_history + [{"role": "user", "content": prompt}]
|
379 |
+
|
380 |
+
response = openai_completion_with_backoff(client,
|
381 |
+
model=model,
|
382 |
+
messages=messages,
|
383 |
+
temperature=temperature,
|
384 |
+
max_tokens=max_tokens,
|
385 |
+
stream=True
|
386 |
+
)
|
387 |
+
|
388 |
+
output_tokens = 0
|
389 |
+
input_tokens = 0
|
390 |
+
for chunk in response:
|
391 |
+
if chunk.choices and chunk.choices[0].delta.content:
|
392 |
+
yield dict(text=chunk.choices[0].delta.content)
|
393 |
+
if chunk.usage:
|
394 |
+
output_tokens = chunk.usage.completion_tokens
|
395 |
+
input_tokens = chunk.usage.prompt_tokens
|
396 |
+
|
397 |
+
if verbose:
|
398 |
+
print(f"Output tokens: {output_tokens}")
|
399 |
+
print(f"Input tokens: {input_tokens}")
|
400 |
+
yield dict(output_tokens=output_tokens, input_tokens=input_tokens)
|
401 |
+
|
402 |
+
|
403 |
+
def to_list(x):
|
404 |
+
if x:
|
405 |
+
try:
|
406 |
+
ollama_model_list = ast.literal_eval(x)
|
407 |
+
assert isinstance(ollama_model_list, list)
|
408 |
+
except:
|
409 |
+
x = [x]
|
410 |
+
else:
|
411 |
+
x = []
|
412 |
+
return x
|
413 |
+
|
414 |
+
|
415 |
+
def get_model_names(secrets, on_hf_spaces=False):
|
416 |
+
if not on_hf_spaces:
|
417 |
+
secrets = os.environ
|
418 |
+
if secrets.get('ANTHROPIC_API_KEY'):
|
419 |
+
anthropic_models = ['claude-3-5-sonnet-20240620', 'claude-3-haiku-20240307', 'claude-3-opus-20240229']
|
420 |
+
else:
|
421 |
+
anthropic_models = []
|
422 |
+
if secrets.get('OPENAI_API_KEY'):
|
423 |
+
if os.getenv('OPENAI_MODEL_NAME'):
|
424 |
+
openai_models = to_list(os.getenv('OPENAI_MODEL_NAME'))
|
425 |
+
else:
|
426 |
+
openai_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
|
427 |
+
else:
|
428 |
+
openai_models = []
|
429 |
+
if secrets.get('AZURE_OPENAI_API_KEY'):
|
430 |
+
if os.getenv('AZURE_OPENAI_MODEL_NAME'):
|
431 |
+
azure_models = to_list(os.getenv('AZURE_OPENAI_MODEL_NAME'))
|
432 |
+
else:
|
433 |
+
azure_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
|
434 |
+
else:
|
435 |
+
azure_models = []
|
436 |
+
if secrets.get('GEMINI_API_KEY'):
|
437 |
+
google_models = ['gemini-1.5-pro-latest', 'gemini-1.5-flash-latest']
|
438 |
+
else:
|
439 |
+
google_models = []
|
440 |
+
if secrets.get('GROQ_API_KEY'):
|
441 |
+
groq_models = ['llama-3.1-70b-versatile',
|
442 |
+
'llama-3.1-8b-instant',
|
443 |
+
'llama3-groq-70b-8192-tool-use-preview',
|
444 |
+
'llama3-groq-8b-8192-tool-use-preview',
|
445 |
+
'mixtral-8x7b-32768']
|
446 |
+
else:
|
447 |
+
groq_models = []
|
448 |
+
if secrets.get('OLLAMA_OPENAI_API_KEY'):
|
449 |
+
ollama_model = os.environ['OLLAMA_OPENAI_MODEL_NAME']
|
450 |
+
ollama_model = to_list(ollama_model)
|
451 |
+
else:
|
452 |
+
ollama_model = []
|
453 |
+
|
454 |
+
groq_models = ['groq:' + x for x in groq_models]
|
455 |
+
azure_models = ['azure:' + x for x in azure_models]
|
456 |
+
openai_models = ['openai:' + x for x in openai_models]
|
457 |
+
google_models = ['google:' + x for x in google_models]
|
458 |
+
anthropic_models = ['anthropic:' + x for x in anthropic_models]
|
459 |
+
ollama = ['ollama:' + x if 'ollama:' not in x else x for x in ollama_model]
|
460 |
+
|
461 |
+
return anthropic_models, openai_models, google_models, groq_models, azure_models, ollama
|
462 |
+
|
463 |
+
|
464 |
+
def get_all_model_names(secrets, on_hf_spaces=False):
|
465 |
+
anthropic_models, openai_models, google_models, groq_models, azure_models, ollama = get_model_names(secrets,
|
466 |
+
on_hf_spaces=on_hf_spaces)
|
467 |
+
return anthropic_models + openai_models + google_models + groq_models + azure_models + ollama
|
468 |
+
|
469 |
+
|
470 |
+
def get_model_api(model: str):
|
471 |
+
if model.startswith('anthropic:'):
|
472 |
+
return get_anthropic
|
473 |
+
elif model.startswith('openai:') or model.startswith('ollama:'):
|
474 |
+
return get_openai
|
475 |
+
elif model.startswith('google:'):
|
476 |
+
return get_google
|
477 |
+
elif model.startswith('groq:'):
|
478 |
+
return get_groq
|
479 |
+
elif model.startswith('azure:'):
|
480 |
+
return get_openai_azure
|
481 |
+
else:
|
482 |
+
raise ValueError(
|
483 |
+
f"Unsupported model: {model}. Ensure to add prefix (e.g. openai:, google:, groq:, azure:, ollama:, anthropic:)")
|
open_strawberry.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
from typing import List, Dict, Generator, Tuple
|
5 |
+
from collections import deque
|
6 |
+
|
7 |
+
try:
|
8 |
+
from src.models import get_model_api
|
9 |
+
from src.utils import get_turn_title, get_final_answer, get_xml_tag_value
|
10 |
+
except (ModuleNotFoundError, ImportError):
|
11 |
+
from models import get_model_api
|
12 |
+
from utils import get_turn_title, get_final_answer, get_xml_tag_value
|
13 |
+
|
14 |
+
|
15 |
+
class DeductionTracker:
|
16 |
+
def __init__(self):
|
17 |
+
self.deductions = []
|
18 |
+
self.certainty_scores = []
|
19 |
+
|
20 |
+
def add_deduction(self, deduction: str, certainty: float):
|
21 |
+
self.deductions.append(deduction)
|
22 |
+
self.certainty_scores.append(certainty)
|
23 |
+
|
24 |
+
def get_deductions(self):
|
25 |
+
return list(zip(self.deductions, self.certainty_scores))
|
26 |
+
|
27 |
+
def update_certainty(self, index: int, new_certainty: float):
|
28 |
+
if 0 <= index < len(self.certainty_scores):
|
29 |
+
self.certainty_scores[index] = new_certainty
|
30 |
+
|
31 |
+
|
32 |
+
class ProblemRepresentation:
|
33 |
+
def __init__(self):
|
34 |
+
self.current_representation = ""
|
35 |
+
|
36 |
+
def update(self, new_representation: str):
|
37 |
+
self.current_representation = new_representation
|
38 |
+
|
39 |
+
def get(self) -> str:
|
40 |
+
return self.current_representation
|
41 |
+
|
42 |
+
|
43 |
+
def get_last_assistant_responses(chat_history, n=3):
|
44 |
+
assistant_messages = [msg['content'] for msg in chat_history if msg['role'] == 'assistant']
|
45 |
+
return assistant_messages[-n:]
|
46 |
+
|
47 |
+
|
48 |
+
def generate_dynamic_system_prompt(base_prompt, turn_count: int, problem_complexity: float, problem_representation: str,
|
49 |
+
deductions: List[Tuple[str, float]]) -> str:
|
50 |
+
dynamic_prompt = base_prompt + "\n\n* Always refer to and update the current problem representation as needed."
|
51 |
+
dynamic_prompt += "\n* Maintain and update the list of deductions and their certainty scores."
|
52 |
+
|
53 |
+
if turn_count > 20:
|
54 |
+
dynamic_prompt += "\n* At this stage, focus on synthesizing your previous thoughts and looking for breakthrough insights."
|
55 |
+
|
56 |
+
if problem_complexity > 0.7:
|
57 |
+
dynamic_prompt += "\n* This is a highly complex problem. Consider breaking it down into smaller subproblems and solving them incrementally."
|
58 |
+
|
59 |
+
dynamic_prompt += "\n* Regularly verify that your current understanding satisfies ALL given clues."
|
60 |
+
dynamic_prompt += "\n* If you reach a contradiction, backtrack to the last point where you were certain and explore alternative paths."
|
61 |
+
|
62 |
+
dynamic_prompt += f"\n\nCurrent problem representation:\n{problem_representation}"
|
63 |
+
dynamic_prompt += "\n\nYou can update this representation by providing a new one within <representation></representation> tags."
|
64 |
+
|
65 |
+
dynamic_prompt += "\n\nCurrent deductions and certainty scores:"
|
66 |
+
for deduction, certainty in deductions:
|
67 |
+
dynamic_prompt += f"\n- {deduction} (Certainty: {certainty})"
|
68 |
+
dynamic_prompt += "\n\nYou can add new deductions or update existing ones using <deduction></deduction> and <certainty></certainty> tags."
|
69 |
+
|
70 |
+
return dynamic_prompt
|
71 |
+
|
72 |
+
|
73 |
+
def generate_initial_representation_prompt(initial_prompt: str) -> str:
|
74 |
+
return f"""Based on the following problem description:
|
75 |
+
{initial_prompt}
|
76 |
+
|
77 |
+
Representation:
|
78 |
+
* Create a clear and clean representation that breaks down the problem into parts in order to create a structure that helps track solving the problem.
|
79 |
+
* Put the representation inside <representation> </representation> XML tags, ensuring to add new lines before and after XML tags.
|
80 |
+
* The representation could be a table, matrix, grid, or any other format that breaks down the problem into its components and ensure it helps iteratively track progress towards the solution.
|
81 |
+
* Example representations include a matrix that has values of digits as rows and position of digits as columns and values at each row-column as tracking confirmed position, eliminated position, or possible position.
|
82 |
+
* For a table or grid representation, you must put the table or grid inside a Markdown code block (with new lines around the back ticks) and make it nice and easy to read for a human to understand it.
|
83 |
+
|
84 |
+
Deductions:
|
85 |
+
* Provide your initial deductions (if any) using <deduction></deduction> tags, each followed by a certainty score in <certainty></certainty> tags (0-100).
|
86 |
+
"""
|
87 |
+
|
88 |
+
|
89 |
+
def generate_verification_prompt(chat_history: List[Dict], turn_count: int, problem_representation: str,
|
90 |
+
deductions: List[Tuple[str, float]]) -> str:
|
91 |
+
last_responses = get_last_assistant_responses(chat_history, n=5)
|
92 |
+
|
93 |
+
verification_prompt = f"""Turn {turn_count}: Comprehensive Verification and Critique
|
94 |
+
|
95 |
+
1. Review your previous reasoning steps:
|
96 |
+
{' '.join(last_responses)}
|
97 |
+
|
98 |
+
2. Current problem representation:
|
99 |
+
{problem_representation}
|
100 |
+
|
101 |
+
3. Current deductions and certainty scores:
|
102 |
+
{deductions}
|
103 |
+
|
104 |
+
4. Perform the following checks:
|
105 |
+
a) Identify any logical fallacies or unjustified assumptions
|
106 |
+
b) Check for mathematical or factual errors
|
107 |
+
c) Assess the relevance of each step to the main problem
|
108 |
+
d) Evaluate the coherence and consistency of your reasoning
|
109 |
+
e) Verify that your current understanding satisfies ALL given clues
|
110 |
+
f) Check if any of your deductions contradict each other
|
111 |
+
|
112 |
+
5. If you find any issues:
|
113 |
+
a) Explain the issue in detail
|
114 |
+
b) Correct the error or resolve the contradiction
|
115 |
+
c) Update the problem representation if necessary
|
116 |
+
d) Update deductions and certainty scores as needed
|
117 |
+
|
118 |
+
6. If no issues are found, suggest a new approach or perspective to consider.
|
119 |
+
|
120 |
+
7. Assign an overall confidence score (0-100) to your current reasoning path and explain why.
|
121 |
+
|
122 |
+
Respond in this format:
|
123 |
+
<verification>
|
124 |
+
[Your detailed verification and critique]
|
125 |
+
</verification>
|
126 |
+
<updates>
|
127 |
+
[Any updates to the problem representation or deductions]
|
128 |
+
</updates>
|
129 |
+
<confidence_score>[0-100]</confidence_score>
|
130 |
+
<explanation>[Explanation for the confidence score]</explanation>
|
131 |
+
|
132 |
+
If you need to update the problem representation, provide the new representation within <representation></representation> tags.
|
133 |
+
For new or updated deductions, use <deduction></deduction> tags, each followed by <certainty></certainty> tags.
|
134 |
+
"""
|
135 |
+
return verification_prompt
|
136 |
+
|
137 |
+
|
138 |
+
def generate_hypothesis_prompt(chat_history: List[Dict]) -> str:
|
139 |
+
return """Based on your current understanding of the problem:
|
140 |
+
1. Generate three distinct hypotheses that could lead to a solution.
|
141 |
+
2. For each hypothesis, provide a brief rationale and a potential test to validate it.
|
142 |
+
3. Rank these hypotheses in order of perceived promise.
|
143 |
+
|
144 |
+
Respond in this format:
|
145 |
+
<hypotheses>
|
146 |
+
1. [Hypothesis 1]
|
147 |
+
Rationale: [Brief explanation]
|
148 |
+
Test: [Proposed validation method]
|
149 |
+
|
150 |
+
2. [Hypothesis 2]
|
151 |
+
Rationale: [Brief explanation]
|
152 |
+
Test: [Proposed validation method]
|
153 |
+
|
154 |
+
3. [Hypothesis 3]
|
155 |
+
Rationale: [Brief explanation]
|
156 |
+
Test: [Proposed validation method]
|
157 |
+
</hypotheses>
|
158 |
+
<ranking>[Your ranking and brief justification]</ranking>
|
159 |
+
"""
|
160 |
+
|
161 |
+
|
162 |
+
def generate_analogical_reasoning_prompt(problem_description: str) -> str:
|
163 |
+
return f"""Consider the following problem:
|
164 |
+
{problem_description}
|
165 |
+
|
166 |
+
Now, think of an analogous problem from a different domain that shares similar structural characteristics. Describe:
|
167 |
+
1. The analogous problem
|
168 |
+
2. The key similarities between the original and analogous problems
|
169 |
+
3. How the solution to the analogous problem might inform our approach to the original problem
|
170 |
+
|
171 |
+
Respond in this format:
|
172 |
+
<analogy>
|
173 |
+
Problem: [Description of the analogous problem]
|
174 |
+
Similarities: [Key structural similarities]
|
175 |
+
Insights: [How this analogy might help solve the original problem]
|
176 |
+
</analogy>
|
177 |
+
"""
|
178 |
+
|
179 |
+
|
180 |
+
def generate_metacognitive_prompt() -> str:
|
181 |
+
return """Take a step back and reflect on your problem-solving process:
|
182 |
+
1. What strategies have been most effective so far?
|
183 |
+
2. What are the main obstacles you're facing?
|
184 |
+
3. How might you adjust your approach to overcome these obstacles?
|
185 |
+
4. What assumptions might you be making that could be limiting your progress?
|
186 |
+
|
187 |
+
Respond in this format:
|
188 |
+
<metacognition>
|
189 |
+
Effective Strategies: [List and brief explanation]
|
190 |
+
Main Obstacles: [List and brief explanation]
|
191 |
+
Proposed Adjustments: [List of potential changes to your approach]
|
192 |
+
Potential Limiting Assumptions: [List and brief explanation]
|
193 |
+
</metacognition>
|
194 |
+
"""
|
195 |
+
|
196 |
+
|
197 |
+
def generate_devils_advocate_prompt(current_approach: str) -> str:
|
198 |
+
return f"""Consider your current approach:
|
199 |
+
{current_approach}
|
200 |
+
|
201 |
+
Now, play the role of a skeptical critic:
|
202 |
+
1. What are the three strongest arguments against this approach?
|
203 |
+
2. What critical information might we be overlooking?
|
204 |
+
3. How might this approach fail in extreme or edge cases?
|
205 |
+
|
206 |
+
Respond in this format:
|
207 |
+
<devils_advocate>
|
208 |
+
Counter-arguments:
|
209 |
+
1. [First strong counter-argument]
|
210 |
+
2. [Second strong counter-argument]
|
211 |
+
3. [Third strong counter-argument]
|
212 |
+
|
213 |
+
Overlooked Information: [Potential critical information we might be missing]
|
214 |
+
|
215 |
+
Potential Failures: [How this approach might fail in extreme or edge cases]
|
216 |
+
</devils_advocate>
|
217 |
+
"""
|
218 |
+
|
219 |
+
|
220 |
+
def generate_hint(problem_description: str, current_progress: str, difficulty: float) -> str:
|
221 |
+
if difficulty < 0.3:
|
222 |
+
hint_level = "subtle"
|
223 |
+
elif difficulty < 0.7:
|
224 |
+
hint_level = "moderate"
|
225 |
+
else:
|
226 |
+
hint_level = "strong"
|
227 |
+
|
228 |
+
return f"""Based on the original problem:
|
229 |
+
{problem_description}
|
230 |
+
|
231 |
+
And the current progress:
|
232 |
+
{current_progress}
|
233 |
+
|
234 |
+
Provide a {hint_level} hint to help move towards the solution without giving it away entirely.
|
235 |
+
|
236 |
+
<hint>
|
237 |
+
[Your {hint_level} hint here]
|
238 |
+
</hint>
|
239 |
+
"""
|
240 |
+
|
241 |
+
|
242 |
+
def summarize_and_restructure(chat_history: List[Dict]) -> str:
|
243 |
+
return """Review the entire conversation history and provide:
|
244 |
+
1. A concise summary of the key insights and progress made so far
|
245 |
+
2. A restructured presentation of the problem based on our current understanding
|
246 |
+
3. Identification of any patterns or recurring themes in our problem-solving attempts
|
247 |
+
|
248 |
+
Respond in this format:
|
249 |
+
<summary_and_restructure>
|
250 |
+
Key Insights: [Bullet point list of main insights]
|
251 |
+
Restructured Problem: [Revised problem statement based on current understanding]
|
252 |
+
Patterns/Themes: [Identified patterns or recurring themes in our approach]
|
253 |
+
</summary_and_restructure>
|
254 |
+
"""
|
255 |
+
|
256 |
+
|
257 |
+
class Memory:
|
258 |
+
def __init__(self, max_size=10):
|
259 |
+
self.insights = deque(maxlen=max_size)
|
260 |
+
self.mistakes = deque(maxlen=max_size)
|
261 |
+
self.dead_ends = deque(maxlen=max_size)
|
262 |
+
|
263 |
+
def add_insight(self, insight: str):
|
264 |
+
self.insights.append(insight)
|
265 |
+
|
266 |
+
def add_mistake(self, mistake: str):
|
267 |
+
self.mistakes.append(mistake)
|
268 |
+
|
269 |
+
def add_dead_end(self, dead_end: str):
|
270 |
+
self.dead_ends.append(dead_end)
|
271 |
+
|
272 |
+
def get_insights(self) -> List[str]:
|
273 |
+
return list(self.insights)
|
274 |
+
|
275 |
+
def get_mistakes(self) -> List[str]:
|
276 |
+
return list(self.mistakes)
|
277 |
+
|
278 |
+
def get_dead_ends(self) -> List[str]:
|
279 |
+
return list(self.dead_ends)
|
280 |
+
|
281 |
+
|
282 |
+
def manage_conversation(model: str,
|
283 |
+
system: str,
|
284 |
+
initial_prompt: str,
|
285 |
+
next_prompts: List[str],
|
286 |
+
final_prompt: str = "",
|
287 |
+
num_turns: int = 25,
|
288 |
+
num_turns_final_mod: int = 9,
|
289 |
+
cli_mode: bool = False,
|
290 |
+
temperature: float = 0.3,
|
291 |
+
max_tokens: int = 4096,
|
292 |
+
seed: int = 1234,
|
293 |
+
verbose: bool = False
|
294 |
+
) -> Generator[Dict, None, list]:
|
295 |
+
if seed == 0:
|
296 |
+
seed = random.randint(0, 1000000)
|
297 |
+
random.seed(seed)
|
298 |
+
get_model_func = get_model_api(model)
|
299 |
+
chat_history = []
|
300 |
+
memory = Memory()
|
301 |
+
problem_representation = ProblemRepresentation()
|
302 |
+
deduction_tracker = DeductionTracker()
|
303 |
+
|
304 |
+
turn_count = 0
|
305 |
+
total_thinking_time = 0
|
306 |
+
problem_complexity = 0.5 # Initial estimate, will be dynamically updated
|
307 |
+
|
308 |
+
base_system = system
|
309 |
+
while True:
|
310 |
+
system = generate_dynamic_system_prompt(base_system, turn_count, problem_complexity,
|
311 |
+
problem_representation.get(), deduction_tracker.get_deductions())
|
312 |
+
trying_final = False
|
313 |
+
|
314 |
+
if turn_count == 0:
|
315 |
+
prompt = generate_initial_representation_prompt(initial_prompt)
|
316 |
+
elif turn_count % 5 == 0:
|
317 |
+
prompt = generate_verification_prompt(chat_history, turn_count, problem_representation.get(),
|
318 |
+
deduction_tracker.get_deductions())
|
319 |
+
elif turn_count % 7 == 0:
|
320 |
+
prompt = generate_hypothesis_prompt(chat_history)
|
321 |
+
elif turn_count % 11 == 0:
|
322 |
+
prompt = generate_analogical_reasoning_prompt(initial_prompt)
|
323 |
+
elif turn_count % 13 == 0:
|
324 |
+
prompt = generate_metacognitive_prompt()
|
325 |
+
elif turn_count % 17 == 0:
|
326 |
+
current_approach = get_last_assistant_responses(chat_history, n=1)[0]
|
327 |
+
prompt = generate_devils_advocate_prompt(current_approach)
|
328 |
+
elif turn_count % 19 == 0:
|
329 |
+
current_progress = get_last_assistant_responses(chat_history, n=3)
|
330 |
+
prompt = generate_hint(initial_prompt, "\n".join(current_progress), problem_complexity)
|
331 |
+
elif turn_count % 23 == 0:
|
332 |
+
prompt = summarize_and_restructure(chat_history)
|
333 |
+
elif turn_count % num_turns_final_mod == 0 and turn_count > 0:
|
334 |
+
trying_final = True
|
335 |
+
prompt = final_prompt
|
336 |
+
else:
|
337 |
+
prompt = random.choice(next_prompts)
|
338 |
+
|
339 |
+
if turn_count == 0:
|
340 |
+
yield {"role": "user", "content": initial_prompt, "chat_history": chat_history, "initial": turn_count == 0}
|
341 |
+
else:
|
342 |
+
yield {"role": "user", "content": prompt, "chat_history": chat_history, "initial": turn_count == 0}
|
343 |
+
|
344 |
+
thinking_time = time.time()
|
345 |
+
response_text = ''
|
346 |
+
for chunk in get_model_func(model, prompt, system=system, chat_history=chat_history,
|
347 |
+
temperature=temperature, max_tokens=max_tokens, verbose=verbose):
|
348 |
+
if 'text' in chunk and chunk['text']:
|
349 |
+
response_text += chunk['text']
|
350 |
+
yield {"role": "assistant", "content": chunk['text'], "streaming": True, "chat_history": chat_history,
|
351 |
+
"final": False, "turn_title": False}
|
352 |
+
else:
|
353 |
+
yield {"role": "usage", "content": chunk}
|
354 |
+
thinking_time = time.time() - thinking_time
|
355 |
+
total_thinking_time += thinking_time
|
356 |
+
|
357 |
+
# Update problem complexity based on thinking time and response length
|
358 |
+
problem_complexity = min(1.0,
|
359 |
+
problem_complexity + (thinking_time / 60) * 0.1 + (len(response_text) / 1000) * 0.05)
|
360 |
+
|
361 |
+
# Extract and update problem representation
|
362 |
+
representations = get_xml_tag_value(response_text, 'representation', ret_all=False)
|
363 |
+
if representations:
|
364 |
+
problem_representation.update(representations[-1])
|
365 |
+
|
366 |
+
# Extract and update deductions
|
367 |
+
deductions = get_xml_tag_value(response_text, 'deduction')
|
368 |
+
for deduction in deductions:
|
369 |
+
certainties = get_xml_tag_value(response_text, 'certainty', ret_all=False)
|
370 |
+
if certainties:
|
371 |
+
deduction_tracker.add_deduction(deduction, certainties[-1])
|
372 |
+
|
373 |
+
# Extract insights, mistakes, and dead ends from the response
|
374 |
+
[memory.add_insight(x) for x in get_xml_tag_value(response_text, 'insight')]
|
375 |
+
[memory.add_mistake(x) for x in get_xml_tag_value(response_text, 'mistake')]
|
376 |
+
[memory.add_dead_end(x) for x in get_xml_tag_value(response_text, 'dead_end')]
|
377 |
+
|
378 |
+
turn_title = get_turn_title(response_text)
|
379 |
+
yield {"role": "assistant", "content": turn_title, "turn_title": True, 'thinking_time': thinking_time,
|
380 |
+
'total_thinking_time': total_thinking_time}
|
381 |
+
|
382 |
+
chat_history.append(
|
383 |
+
{"role": "user",
|
384 |
+
"content": [{"type": "text", "text": prompt, "cache_control": {"type": "ephemeral"}}]})
|
385 |
+
chat_history.append({"role": "assistant", "content": response_text})
|
386 |
+
|
387 |
+
# Adjusted to only check final answer when trying_final is True
|
388 |
+
always_check_final = False
|
389 |
+
if trying_final or always_check_final:
|
390 |
+
final_value = get_final_answer(response_text, cli_mode=cli_mode)
|
391 |
+
if final_value:
|
392 |
+
chat_history.append({"role": "assistant", "content": final_value})
|
393 |
+
yield {"role": "assistant", "content": final_value, "streaming": True, "chat_history": chat_history,
|
394 |
+
"final": True}
|
395 |
+
break
|
396 |
+
|
397 |
+
turn_count += 1
|
398 |
+
|
399 |
+
# Dynamically adjust temperature based on progress
|
400 |
+
if turn_count % 10 == 0:
|
401 |
+
temperature = min(1.0, temperature + 0.1) # Gradually increase temperature to encourage exploration
|
402 |
+
|
403 |
+
if turn_count % num_turns == 0:
|
404 |
+
# periodically pause for continuation, never have to fully terminate
|
405 |
+
if cli_mode:
|
406 |
+
user_continue = input("\nContinue? (y/n): ").lower() == 'y'
|
407 |
+
if not user_continue:
|
408 |
+
break
|
409 |
+
else:
|
410 |
+
yield {"role": "action", "content": "continue?", "chat_history": chat_history}
|
411 |
+
|
412 |
+
time.sleep(0.001)
|
413 |
+
|
414 |
+
|
415 |
+
def get_defaults() -> Tuple:
|
416 |
+
initial_prompt = """Can you crack the code?
|
417 |
+
9 2 8 5 (One number is correct but in the wrong position)
|
418 |
+
1 9 3 7 (Two numbers are correct but in the wrong positions)
|
419 |
+
5 2 0 1 (one number is correct and in the right position)
|
420 |
+
6 5 0 7 (nothing is correct)
|
421 |
+
8 5 2 4 (two numbers are correct but in the wrong positions)"""
|
422 |
+
|
423 |
+
expected_answer = "3841"
|
424 |
+
|
425 |
+
system_prompt = """Let us play a game of "take only the most minuscule step toward the solution."
|
426 |
+
<thinking_game>
|
427 |
+
* The assistant's text output must be only the very next possible step.
|
428 |
+
* Use your text output as a scratch pad in addition to a literal output of some next step.
|
429 |
+
* Every time you make a major shift in thinking, output your high-level current thinking in <thinking> </thinking> XML tags.
|
430 |
+
* You should present your response in a way that iterates on that scratch pad space with surrounding textual context.
|
431 |
+
* You win the game if you are able to take the smallest text steps possible while still (on average) heading towards the solution.
|
432 |
+
* Backtracking is allowed, and generating python code is allowed (but will not be executed, but can be used to think), just on average over many text output turns you must head towards the answer.
|
433 |
+
* You must think using first principles, and ensure you identify inconsistencies, errors, etc.
|
434 |
+
* Periodically, you should review your previous reasoning steps and check for errors or inconsistencies. If you find any, correct them.
|
435 |
+
* You MUST always end with a very brief natural language title (it should just describe the analysis, do not give step numbers) of what you did inside <turn_title> </turn_title> XML tags. Only a single title is allowed.
|
436 |
+
* Do not provide the final answer unless the user specifically requests it using the final prompt.
|
437 |
+
</thinking_game>
|
438 |
+
Remember to compensate for your flaws:
|
439 |
+
<system_flaws>
|
440 |
+
* Flaw 1: Bad at counting due to tokenization issues. Expand word with spaces between first and then only count that expanded version.
|
441 |
+
* Flaw 2: Grade school or advanced math. Solve such problems very carefully step-by-step.
|
442 |
+
</system_flaws>
|
443 |
+
"""
|
444 |
+
|
445 |
+
next_prompts = [
|
446 |
+
"Continue your effort to answer the original query. What's your next step?",
|
447 |
+
"What aspect of the problem haven't we considered yet?",
|
448 |
+
"Can you identify any patterns or relationships in the given information?",
|
449 |
+
"How would you verify your current reasoning?",
|
450 |
+
"What's the weakest part of your current approach? How can we strengthen it?",
|
451 |
+
"If you were to explain your current thinking to a novice, what would you say?",
|
452 |
+
"What alternative perspectives could we consider?",
|
453 |
+
"How does your current approach align with the constraints of the problem?",
|
454 |
+
"What assumptions are we making? Are they all necessary?",
|
455 |
+
"If we were to start over with our current knowledge, how would our approach differ?",
|
456 |
+
]
|
457 |
+
|
458 |
+
final_prompt = """Verification checklist:
|
459 |
+
1) Do you have very high confidence in a final answer?
|
460 |
+
2) Have you fully verified your answer with all the time and resources you have?
|
461 |
+
3) If you have very high confidence AND you have fully verified your answer with all resources possible, then put the final answer in <final_answer> </final_answer> XML tags, otherwise please continue to vigorously work on the user's original query.
|
462 |
+
"""
|
463 |
+
|
464 |
+
num_turns = int(os.getenv('NUM_TURNS', '10')) # Number of turns before pausing for continuation
|
465 |
+
num_turns_final_mod = num_turns - 1 # Not required, just an OK value. Could be randomized.
|
466 |
+
|
467 |
+
show_next = False
|
468 |
+
show_cot = False
|
469 |
+
verbose = False
|
470 |
+
|
471 |
+
# model = "claude-3-5-sonnet-20240620"
|
472 |
+
model = "anthropic:claude-3-haiku-20240307"
|
473 |
+
|
474 |
+
temperature = 0.3
|
475 |
+
max_tokens = 4096
|
476 |
+
|
477 |
+
return (model, system_prompt,
|
478 |
+
initial_prompt,
|
479 |
+
expected_answer,
|
480 |
+
next_prompts,
|
481 |
+
num_turns, show_next, final_prompt,
|
482 |
+
temperature, max_tokens,
|
483 |
+
num_turns_final_mod,
|
484 |
+
show_cot,
|
485 |
+
verbose)
|
486 |
+
|
487 |
+
|
488 |
+
if __name__ == '__main__':
|
489 |
+
from src.cli import go_cli
|
490 |
+
|
491 |
+
go_cli()
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python-dotenv
|
2 |
+
|
3 |
+
anthropic
|
4 |
+
openai
|
5 |
+
google-generativeai
|
6 |
+
groq
|
7 |
+
|
8 |
+
streamlit
|
9 |
+
blinker
|
10 |
+
click
|
11 |
+
pydantic
|
12 |
+
|
13 |
+
tenacity
|
utils.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
|
4 |
+
def get_turn_title(response_text):
|
5 |
+
tag = 'turn_title'
|
6 |
+
pattern = fr'<{tag}>(.*?)</{tag}>'
|
7 |
+
values = re.findall(pattern, response_text, re.DOTALL)
|
8 |
+
values0 = values.copy()
|
9 |
+
values = [v.strip() for v in values]
|
10 |
+
values = [v for v in values if v]
|
11 |
+
if len(values) == 0:
|
12 |
+
# then maybe removed too much
|
13 |
+
values = values0
|
14 |
+
if values:
|
15 |
+
turn_title = values[-1]
|
16 |
+
else:
|
17 |
+
turn_title = response_text[:15] + '...'
|
18 |
+
return turn_title
|
19 |
+
|
20 |
+
|
21 |
+
def get_final_answer(response_text, cli_mode=False):
|
22 |
+
tag = 'final_answer'
|
23 |
+
pattern = fr'<{tag}>(.*?)</{tag}>'
|
24 |
+
values = re.findall(pattern, response_text, re.DOTALL)
|
25 |
+
values0 = values.copy()
|
26 |
+
values = [v.strip() for v in values]
|
27 |
+
values = [v for v in values if v]
|
28 |
+
if len(values) == 0:
|
29 |
+
# then maybe removed too much
|
30 |
+
values = values0
|
31 |
+
if values:
|
32 |
+
if cli_mode:
|
33 |
+
response_text = '\n\nFINAL ANSWER:\n\n' + values[-1] + '\n\n'
|
34 |
+
else:
|
35 |
+
response_text = values[-1]
|
36 |
+
else:
|
37 |
+
response_text = None
|
38 |
+
return response_text
|
39 |
+
|
40 |
+
|
41 |
+
def get_xml_tag_value(response_text, tag, ret_all=True):
|
42 |
+
pattern = fr'<{tag}>(.*?)</{tag}>'
|
43 |
+
values = re.findall(pattern, response_text, re.DOTALL)
|
44 |
+
values0 = values.copy()
|
45 |
+
values = [v.strip() for v in values]
|
46 |
+
values = [v for v in values if v]
|
47 |
+
if len(values) == 0:
|
48 |
+
# then maybe removed too much
|
49 |
+
values = values0
|
50 |
+
if values:
|
51 |
+
if ret_all:
|
52 |
+
ret = values
|
53 |
+
else:
|
54 |
+
ret = [values[-1]]
|
55 |
+
else:
|
56 |
+
ret = []
|
57 |
+
return ret
|