lewtun (HF staff) committed
Commit b9997a4 • 1 Parent(s): 5e588f0
Files changed (1)
  1. app.py +51 -66
app.py CHANGED
@@ -1,10 +1,13 @@
 # Basic example for doing model-in-the-loop dynamic adversarial data collection
 # using Gradio Blocks.
+import concurrent.futures
 import json
 import os
 import threading
+import time
 import uuid
 from pathlib import Path
+from typing import List
 from urllib.parse import parse_qs
 
 import gradio as gr
@@ -17,15 +20,27 @@ from langchain.prompts import load_prompt
 
 from utils import force_git_push
 
-# These variables are for storing the mturk HITs in a Hugging Face dataset.
+
+def generate_response(chatbot: ConversationChain, input: str) -> str:
+    """Generates a response for a `langchain` chatbot."""
+    return chatbot.predict(input=input)
+
+def generate_responses(chatbots: List[ConversationChain], inputs: List[str]) -> List[str]:
+    """Generates parallel responses for a list of `langchain` chatbots."""
+    results = []
+    executor = concurrent.futures.ThreadPoolExecutor(max_workers=100)
+    for result in executor.map(generate_response, chatbots, inputs):
+        results.append(result)
+    return results
+
+
+# These variables are for storing the MTurk HITs in a Hugging Face dataset.
 if Path(".env").is_file():
     load_dotenv(".env")
 DATASET_REPO_URL = os.getenv("DATASET_REPO_URL")
 FORCE_PUSH = os.getenv("FORCE_PUSH")
 HF_TOKEN = os.getenv("HF_TOKEN")
 PROMPT_TEMPLATES = Path("prompt_templates")
-# Set env variable for langchain to communicate with Hugging Face Hub
-os.environ["HUGGINGFACEHUB_API_TOKEN"] = HF_TOKEN
 
 DATA_FILENAME = "data.jsonl"
 DATA_FILE = os.path.join("data", DATA_FILENAME)
@@ -58,52 +73,24 @@ asynchronous_push(f_stop)
 # Now let's run the app!
 prompt = load_prompt(PROMPT_TEMPLATES / "openai_chatgpt.json")
 
-chatbot_1 = ConversationChain(
-    llm=HuggingFaceHub(
-        repo_id="google/flan-t5-xl",
-        model_kwargs={"temperature": 1}
-    ),
-    prompt=prompt,
-    verbose=False,
-    memory=ConversationBufferMemory(ai_prefix="Assistant"),
-)
+# TODO: update this list with better, instruction-trained models
+MODEL_IDS = ["google/flan-t5-xl", "bigscience/T0_3B", "EleutherAI/gpt-j-6B"]
+chatbots = []
 
-chatbot_2 = ConversationChain(
+for model_id in MODEL_IDS:
+    chatbots.append(ConversationChain(
     llm=HuggingFaceHub(
-        repo_id="bigscience/bloom",
-        model_kwargs={"temperature": 0.7}
+        repo_id=model_id,
+        model_kwargs={"temperature": 1},
+        huggingfacehub_api_token=HF_TOKEN,
    ),
    prompt=prompt,
    verbose=False,
    memory=ConversationBufferMemory(ai_prefix="Assistant"),
-)
+    ))
 
-chatbot_3 = ConversationChain(
-    llm=HuggingFaceHub(
-        repo_id="bigscience/T0_3B",
-        model_kwargs={"temperature": 1}
-    ),
-    prompt=prompt,
-    verbose=False,
-    memory=ConversationBufferMemory(ai_prefix="Assistant"),
-)
 
-chatbot_4 = ConversationChain(
-    llm=HuggingFaceHub(
-        repo_id="EleutherAI/gpt-j-6B",
-        model_kwargs={"temperature": 1}
-    ),
-    prompt=prompt,
-    verbose=False,
-    memory=ConversationBufferMemory(ai_prefix="Assistant"),
-)
-
-model_id2model = {
-    "google/flan-t5-xl": chatbot_1,
-    "bigscience/bloom": chatbot_2,
-    "bigscience/T0_3B": chatbot_3,
-    "EleutherAI/gpt-j-6B": chatbot_4
-}
+model_id2model = {chatbot.llm.repo_id: chatbot for chatbot in chatbots}
 
 demo = gr.Blocks()
 
@@ -117,11 +104,9 @@ with demo:
         "cnt": 0, "data": [],
         "past_user_inputs": [],
         "generated_responses": [],
-        "response_1": "",
-        "response_2": "",
-        "response_3": "",
-        "response_4": "",
     }
+    for idx in range(len(chatbots)):
+        state_dict[f"response_{idx+1}"] = ""
     state = gr.JSON(state_dict, visible=False)
 
     gr.Markdown("# RLHF Interface")
@@ -132,26 +117,30 @@ with demo:
    # Generate model prediction
    def _predict(txt, state):
-        # TODO: parallelize this!
-        response_1 = chatbot_1.predict(input=txt)
-        response_2 = chatbot_2.predict(input=txt)
-        response_3 = chatbot_3.predict(input=txt)
-        response_4 = chatbot_4.predict(input=txt)
+        # Generate all model responses in parallel via the thread pool
+        start = time.time()
+        responses = generate_responses(chatbots, [txt] * len(chatbots))
+        print(f"Time taken (threading): {time.time() - start} seconds")
+
 
        response2model_id = {}
-        response2model_id[response_1] = chatbot_1.llm.repo_id
-        response2model_id[response_2] = chatbot_2.llm.repo_id
-        response2model_id[response_3] = chatbot_3.llm.repo_id
-        response2model_id[response_4] = chatbot_4.llm.repo_id
+        for chatbot, response in zip(chatbots, responses):
+            response2model_id[response] = chatbot.llm.repo_id
 
        state["cnt"] += 1
 
        new_state_md = f"Inputs remaining in HIT: {state['cnt']}/{TOTAL_CNT}"
 
-        state["data"].append({"cnt": state["cnt"], "text": txt, "response_1": response_1, "response_2": response_2, "response_3": response_3, "response_4": response_4,"response2model_id": response2model_id})
+        metadata = {"cnt": state["cnt"], "text": txt}
+        for idx, response in enumerate(responses):
+            metadata[f"response_{idx + 1}"] = response
+
+        metadata["response2model_id"] = response2model_id
+
+        state["data"].append(metadata)
        state["past_user_inputs"].append(txt)
 
        past_conversation_string = "<br />".join(["<br />".join(["😃: " + user_input, "🤖: " + model_response]) for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"] + [""])])
-        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True, choices=[response_1, response_2, response_3, response_4], interactive=True, value=response_1), gr.update(value=past_conversation_string), state, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), new_state_md, dummy
+        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True, choices=responses, interactive=True, value=responses[0]), gr.update(value=past_conversation_string), state, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), new_state_md, dummy
 
    def _select_response(selected_response, state, dummy):
        done = state["cnt"] == TOTAL_CNT
@@ -169,7 +158,7 @@ with demo:
        past_conversation_string = "<br />".join(["<br />".join(["😃: " + user_input, "🤖: " + model_response]) for user_input, model_response in zip(state["past_user_inputs"], state["generated_responses"])])
        query = parse_qs(dummy[1:])
        if "assignmentId" in query and query["assignmentId"][0] != "ASSIGNMENT_ID_NOT_AVAILABLE":
-            # It seems that someone is using this app on mturk. We need to
+            # It seems that someone is using this app on MTurk. We need to
            # store the assignmentId in the state before submit_hit_button
            # is clicked. We can do this here in _predict. We need to save the
            # assignmentId so that the turker can get credit for their HIT.
@@ -182,17 +171,13 @@ with demo:
 
        if done:
            # Wipe the memory completely because we will be starting a new HIT soon.
-            chatbot_1.memory = ConversationBufferMemory(ai_prefix="Assistant")
-            chatbot_2.memory = ConversationBufferMemory(ai_prefix="Assistant")
-            chatbot_3.memory = ConversationBufferMemory(ai_prefix="Assistant")
-            chatbot_4.memory = ConversationBufferMemory(ai_prefix="Assistant")
+            for chatbot in chatbots:
+                chatbot.memory = ConversationBufferMemory(ai_prefix="Assistant")
        else:
            # Sync all of the models' memories with the conversation path that
            # was actually taken.
-            chatbot_1.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
-            chatbot_2.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
-            chatbot_3.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
-            chatbot_4.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
+            for chatbot in chatbots:
+                chatbot.memory = model_id2model[state["data"][-1]["response2model_id"][selected_response]].memory
 
        text_input = gr.update(visible=False) if done else gr.update(visible=True)
        return gr.update(visible=False), gr.update(visible=True), text_input, gr.update(visible=False), state, gr.update(value=past_conversation_string), toggle_example_submit, toggle_final_submit, toggle_final_submit_preview,
@@ -207,7 +192,7 @@ with demo:
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")
    with gr.Column(visible=False) as final_submit_preview:
-        submit_hit_button_preview = gr.Button("Submit Work (preview mode; no mturk HIT credit, but your examples will still be stored)")
+        submit_hit_button_preview = gr.Button("Submit Work (preview mode; no MTurk HIT credit, but your examples will still be stored)")
 
    # Button event handlers
    get_window_location_search_js = """
@@ -232,7 +217,7 @@ with demo:
 
    post_hit_js = """
    function(state) {
-        // If there is an assignmentId, then the submitter is on mturk
+        // If there is an assignmentId, then the submitter is on MTurk
        // and has accepted the HIT. So, we need to submit their HIT.
        const form = document.createElement('form');
        form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
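For context, here is a minimal, self-contained sketch of the thread-pool fan-out pattern this commit introduces. The fake_predict function and the model names below are hypothetical stand-ins for the app's ConversationChain.predict calls against the Hugging Face Hub; only the concurrency pattern mirrors the committed helper.

    import concurrent.futures
    import time
    from typing import List

    def fake_predict(model_id: str, prompt: str) -> str:
        """Hypothetical stand-in for chatbot.predict(input=prompt)."""
        time.sleep(1.0)  # simulate a network-bound inference call
        return f"[{model_id}] response to {prompt!r}"

    def generate_responses(model_ids: List[str], prompts: List[str]) -> List[str]:
        """Run one prediction per model concurrently, preserving input order."""
        # executor.map yields results in input order, so responses[i] always
        # belongs to model_ids[i]; the app relies on this when it zips the
        # chatbots and their responses back together in _predict.
        with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
            return list(executor.map(fake_predict, model_ids, prompts))

    if __name__ == "__main__":
        model_ids = ["model-a", "model-b", "model-c"]
        start = time.time()
        responses = generate_responses(model_ids, ["Hello!"] * len(model_ids))
        print(responses)
        # Roughly 1 second rather than 3: the calls overlap in the pool.
        print(f"Time taken (threading): {time.time() - start:.2f} seconds")

Because the work is I/O-bound (HTTP calls to the Hub inference API), threads overlap the waiting despite the GIL. Note that the committed helper builds its ThreadPoolExecutor without a with-block, so its pool is never explicitly shut down; the context manager used in this sketch is the usual pattern.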
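The _select_response hunk relies on Python's reference semantics: pointing every chain's memory at the memory object of the chain that produced the selected response makes all models share one conversation history. A toy illustration of that sync step, with a plain list standing in for ConversationBufferMemory (no langchain required):

    class ToyChain:
        """Minimal stand-in for a langchain ConversationChain."""
        def __init__(self, name: str):
            self.name = name
            self.memory: list = []  # stand-in for ConversationBufferMemory

    chatbots = [ToyChain(f"bot-{i}") for i in range(3)]

    # Suppose the worker picked bot-1's response, so bot-1's memory holds
    # the conversation path that was actually taken.
    winner = chatbots[1]
    winner.memory.append(("Hello!", "Hi there!"))

    # Sync: every chain's memory now *is* the winner's memory object
    # (a shared reference, not a copy), so the next turn continues from
    # the selected path for all models at once.
    for chatbot in chatbots:
        chatbot.memory = winner.memory

    assert all(chatbot.memory is winner.memory for chatbot in chatbots)
    print(chatbots[0].memory)  # [('Hello!', 'Hi there!')]

The done branch in the diff deliberately breaks this sharing: assigning each chain a fresh ConversationBufferMemory gives every model a clean slate before the next HIT.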