Tristan Thrush commited on
Commit
90b6f98
β€’
1 Parent(s): 013ce7b

synched model memory with conversation, and make sure it is wiped for next hit

Browse files
Files changed (1) hide show
  1. app.py +29 -7
app.py CHANGED
@@ -98,6 +98,13 @@ chatbot_4 = ConversationChain(
98
  memory=ConversationBufferMemory(ai_prefix="Assistant"),
99
  )
100
 
 
 
 
 
 
 
 
101
  demo = gr.Blocks()
102
 
103
  with demo:
@@ -130,17 +137,17 @@ with demo:
130
  response_3 = chatbot_3.predict(input=txt)
131
  response_4 = chatbot_4.predict(input=txt)
132
 
133
- response2model = {}
134
- response2model[response_1] = chatbot_1.llm.repo_id
135
- response2model[response_2] = chatbot_2.llm.repo_id
136
- response2model[response_3] = chatbot_3.llm.repo_id
137
- response2model[response_4] = chatbot_4.llm.repo_id
138
 
139
  state["cnt"] += 1
140
 
141
  new_state_md = f"Inputs remaining in HIT: {state['cnt']}/{TOTAL_CNT}"
142
 
143
- state["data"].append({"cnt": state["cnt"], "text": txt, "response_1": response_1, "response_2": response_2, "response_3": response_3, "response_4": response_4,"response2model": response2model})
144
  state["past_user_inputs"].append(txt)
145
 
146
  past_conversation_string = "<br />".join(["<br />".join(["πŸ˜ƒ: " + user_input, "πŸ€–: " + model_response]) for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"] + [""])])
@@ -150,7 +157,7 @@ with demo:
150
  done = state["cnt"] == TOTAL_CNT
151
  state["generated_responses"].append(selected_response)
152
  state["data"][-1]["selected_response"] = selected_response
153
- state["data"][-1]["selected_model"] = state["data"][-1]["response2model"][selected_response]
154
  if state["cnt"] == TOTAL_CNT:
155
  # Write the HIT data to our local dataset because the worker has
156
  # submitted everything now.
@@ -172,6 +179,21 @@ with demo:
172
  else:
173
  toggle_final_submit_preview = gr.update(visible=done)
174
  toggle_final_submit = gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  text_input = gr.update(visible=False) if done else gr.update(visible=True)
176
  return gr.update(visible=False), gr.update(visible=True), text_input, gr.update(visible=False), state, gr.update(value=past_conversation_string), toggle_example_submit, toggle_final_submit, toggle_final_submit_preview,
177
 
 
98
  memory=ConversationBufferMemory(ai_prefix="Assistant"),
99
  )
100
 
101
+ model_id2model = {
102
+ "google/flan-t5-xl": chatbot_1,
103
+ "bigscience/bloom": chatbot_2,
104
+ "bigscience/T0_3B": chatbot_3,
105
+ "EleutherAI/gpt-j-6B": chatbot_4
106
+ }
107
+
108
  demo = gr.Blocks()
109
 
110
  with demo:
 
137
  response_3 = chatbot_3.predict(input=txt)
138
  response_4 = chatbot_4.predict(input=txt)
139
 
140
+ response2model_id = {}
141
+ response2model_id[response_1] = chatbot_1.llm.repo_id
142
+ response2model_id[response_2] = chatbot_2.llm.repo_id
143
+ response2model_id[response_3] = chatbot_3.llm.repo_id
144
+ response2model_id[response_4] = chatbot_4.llm.repo_id
145
 
146
  state["cnt"] += 1
147
 
148
  new_state_md = f"Inputs remaining in HIT: {state['cnt']}/{TOTAL_CNT}"
149
 
150
+ state["data"].append({"cnt": state["cnt"], "text": txt, "response_1": response_1, "response_2": response_2, "response_3": response_3, "response_4": response_4,"response2model_id": response2model_id})
151
  state["past_user_inputs"].append(txt)
152
 
153
  past_conversation_string = "<br />".join(["<br />".join(["πŸ˜ƒ: " + user_input, "πŸ€–: " + model_response]) for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"] + [""])])
 
157
  done = state["cnt"] == TOTAL_CNT
158
  state["generated_responses"].append(selected_response)
159
  state["data"][-1]["selected_response"] = selected_response
160
+ state["data"][-1]["selected_model"] = state["data"][-1]["response2model_id"][selected_response]
161
  if state["cnt"] == TOTAL_CNT:
162
  # Write the HIT data to our local dataset because the worker has
163
  # submitted everything now.
 
179
  else:
180
  toggle_final_submit_preview = gr.update(visible=done)
181
  toggle_final_submit = gr.update(visible=False)
182
+
183
+ if done:
184
+ # Wipe the memory completely because we will be starting a new hit soon.
185
+ chatbot_1.memory = ConversationBufferMemory(ai_prefix="Assistant")
186
+ chatbot_2.memory = ConversationBufferMemory(ai_prefix="Assistant")
187
+ chatbot_3.memory = ConversationBufferMemory(ai_prefix="Assistant")
188
+ chatbot_4.memory = ConversationBufferMemory(ai_prefix="Assistant")
189
+ else:
190
+ # Sync all of the model's memories with the conversation path that
191
+ # was actually taken.
192
+ chatbot_1.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
193
+ chatbot_2.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
194
+ chatbot_3.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
195
+ chatbot_4.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
196
+
197
  text_input = gr.update(visible=False) if done else gr.update(visible=True)
198
  return gr.update(visible=False), gr.update(visible=True), text_input, gr.update(visible=False), state, gr.update(value=past_conversation_string), toggle_example_submit, toggle_final_submit, toggle_final_submit_preview,
199