Spaces:
Configuration error
Configuration error
Tristan Thrush
commited on
Commit
β’
90b6f98
1
Parent(s):
013ce7b
synched model memory with conversation, and make sure it is wiped for next hit
Browse files
app.py
CHANGED
@@ -98,6 +98,13 @@ chatbot_4 = ConversationChain(
|
|
98 |
memory=ConversationBufferMemory(ai_prefix="Assistant"),
|
99 |
)
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
demo = gr.Blocks()
|
102 |
|
103 |
with demo:
|
@@ -130,17 +137,17 @@ with demo:
|
|
130 |
response_3 = chatbot_3.predict(input=txt)
|
131 |
response_4 = chatbot_4.predict(input=txt)
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
|
139 |
state["cnt"] += 1
|
140 |
|
141 |
new_state_md = f"Inputs remaining in HIT: {state['cnt']}/{TOTAL_CNT}"
|
142 |
|
143 |
-
state["data"].append({"cnt": state["cnt"], "text": txt, "response_1": response_1, "response_2": response_2, "response_3": response_3, "response_4": response_4,"
|
144 |
state["past_user_inputs"].append(txt)
|
145 |
|
146 |
past_conversation_string = "<br />".join(["<br />".join(["π: " + user_input, "π€: " + model_response]) for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"] + [""])])
|
@@ -150,7 +157,7 @@ with demo:
|
|
150 |
done = state["cnt"] == TOTAL_CNT
|
151 |
state["generated_responses"].append(selected_response)
|
152 |
state["data"][-1]["selected_response"] = selected_response
|
153 |
-
state["data"][-1]["selected_model"] = state["data"][-1]["
|
154 |
if state["cnt"] == TOTAL_CNT:
|
155 |
# Write the HIT data to our local dataset because the worker has
|
156 |
# submitted everything now.
|
@@ -172,6 +179,21 @@ with demo:
|
|
172 |
else:
|
173 |
toggle_final_submit_preview = gr.update(visible=done)
|
174 |
toggle_final_submit = gr.update(visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
text_input = gr.update(visible=False) if done else gr.update(visible=True)
|
176 |
return gr.update(visible=False), gr.update(visible=True), text_input, gr.update(visible=False), state, gr.update(value=past_conversation_string), toggle_example_submit, toggle_final_submit, toggle_final_submit_preview,
|
177 |
|
|
|
98 |
memory=ConversationBufferMemory(ai_prefix="Assistant"),
|
99 |
)
|
100 |
|
101 |
+
model_id2model = {
|
102 |
+
"google/flan-t5-xl": chatbot_1,
|
103 |
+
"bigscience/bloom": chatbot_2,
|
104 |
+
"bigscience/T0_3B": chatbot_3,
|
105 |
+
"EleutherAI/gpt-j-6B": chatbot_4
|
106 |
+
}
|
107 |
+
|
108 |
demo = gr.Blocks()
|
109 |
|
110 |
with demo:
|
|
|
137 |
response_3 = chatbot_3.predict(input=txt)
|
138 |
response_4 = chatbot_4.predict(input=txt)
|
139 |
|
140 |
+
response2model_id = {}
|
141 |
+
response2model_id[response_1] = chatbot_1.llm.repo_id
|
142 |
+
response2model_id[response_2] = chatbot_2.llm.repo_id
|
143 |
+
response2model_id[response_3] = chatbot_3.llm.repo_id
|
144 |
+
response2model_id[response_4] = chatbot_4.llm.repo_id
|
145 |
|
146 |
state["cnt"] += 1
|
147 |
|
148 |
new_state_md = f"Inputs remaining in HIT: {state['cnt']}/{TOTAL_CNT}"
|
149 |
|
150 |
+
state["data"].append({"cnt": state["cnt"], "text": txt, "response_1": response_1, "response_2": response_2, "response_3": response_3, "response_4": response_4,"response2model_id": response2model_id})
|
151 |
state["past_user_inputs"].append(txt)
|
152 |
|
153 |
past_conversation_string = "<br />".join(["<br />".join(["π: " + user_input, "π€: " + model_response]) for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"] + [""])])
|
|
|
157 |
done = state["cnt"] == TOTAL_CNT
|
158 |
state["generated_responses"].append(selected_response)
|
159 |
state["data"][-1]["selected_response"] = selected_response
|
160 |
+
state["data"][-1]["selected_model"] = state["data"][-1]["response2model_id"][selected_response]
|
161 |
if state["cnt"] == TOTAL_CNT:
|
162 |
# Write the HIT data to our local dataset because the worker has
|
163 |
# submitted everything now.
|
|
|
179 |
else:
|
180 |
toggle_final_submit_preview = gr.update(visible=done)
|
181 |
toggle_final_submit = gr.update(visible=False)
|
182 |
+
|
183 |
+
if done:
|
184 |
+
# Wipe the memory completely because we will be starting a new hit soon.
|
185 |
+
chatbot_1.memory = ConversationBufferMemory(ai_prefix="Assistant")
|
186 |
+
chatbot_2.memory = ConversationBufferMemory(ai_prefix="Assistant")
|
187 |
+
chatbot_3.memory = ConversationBufferMemory(ai_prefix="Assistant")
|
188 |
+
chatbot_4.memory = ConversationBufferMemory(ai_prefix="Assistant")
|
189 |
+
else:
|
190 |
+
# Sync all of the model's memories with the conversation path that
|
191 |
+
# was actually taken.
|
192 |
+
chatbot_1.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
|
193 |
+
chatbot_2.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
|
194 |
+
chatbot_3.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
|
195 |
+
chatbot_4.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
|
196 |
+
|
197 |
text_input = gr.update(visible=False) if done else gr.update(visible=True)
|
198 |
return gr.update(visible=False), gr.update(visible=True), text_input, gr.update(visible=False), state, gr.update(value=past_conversation_string), toggle_example_submit, toggle_final_submit, toggle_final_submit_preview,
|
199 |
|