tmzh committed on
Commit
aedb933
1 Parent(s): e640a42

refactoring

Browse files
Files changed (1) hide show
  1. app.py +116 -73
app.py CHANGED
@@ -5,7 +5,8 @@ import os
5
  import random
6
  import json
7
  import re
8
- from typing import List
 
9
 
10
  import gradio as gr
11
  import outlines
@@ -13,44 +14,24 @@ import requests
13
  from outlines import models, generate, samplers
14
  from pydantic import BaseModel
15
 
16
-
17
- def merge_games(clues, num_merges=10):
18
- """Generates around 10 merges of words from the given clues.
19
-
20
- Args:
21
- clues: A list of clues, where each clue is a list containing the words, the answer, and the explanation.
22
- num_merges: The approximate number of merges to generate (default: 10).
23
-
24
- Returns:
25
- A list of tuples, where each tuple contains the merged words and the indices of the selected rows.
26
- """
27
-
28
- merges = []
29
- while len(merges) < num_merges:
30
- num_rows = random.choice([3, 4])
31
- selected_rows = random.sample(range(len(clues)), num_rows)
32
- merged_words = " ".join([word for row in [clues[i][0] for i in selected_rows] for word in row])
33
- if len(merged_words.split()) in [8, 9]:
34
- merges.append((merged_words.split(), selected_rows))
35
-
36
- return merges
37
-
38
 
39
  class Clue(BaseModel):
40
  word: str
41
  explanation: str
42
 
43
-
44
  class Group(BaseModel):
45
  words: List[str]
46
  clue: str
47
  explanation: str
48
 
49
-
50
  class Groups(BaseModel):
51
  groups: List[Group]
52
 
53
-
54
  example_clues = [
55
  (['ARROW', 'TIE', 'HONOR'], 'BOW', 'such as a bow and arrow, a bow tie, or a bow as a sign of honor'),
56
  (['DOG', 'TREE'], 'BARK', 'such as the sound a dog makes, or a tree is made of bark'),
@@ -81,8 +62,36 @@ example_clues = [
81
  'such as the Arctic being home to seals, or shutting a seal on an envelope, or a stamp being a type of seal'),
82
  ]
83
 
 
 
 
84
 
85
- def group_words(words):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  @outlines.prompt
87
  def chat_group_template(system_prompt, query, history=[]):
88
  '''<s><|system|>
@@ -99,10 +108,11 @@ def group_words(words):
99
  '''
100
 
101
  grouping_system_prompt = ("You are an assistant for the game Codenames. Your task is to help players by grouping a "
102
- "given group of secrets into 3 to 4 groups. Each group should consist of secrets that "
103
- "share a common theme or other word connections such as homonym, hypernyms or synonyms")
 
104
  example_groupings = []
105
- merges = merge_games(example_clues, 5)
106
  for merged_words, indices in merges:
107
  groups = [{
108
  "secrets": example_clues[i][0],
@@ -111,20 +121,25 @@ def group_words(words):
111
  } for i in indices]
112
  example_groupings.append((merged_words, json.dumps(groups, separators=(',', ':'))))
113
 
114
- prompt = chat_group_template(grouping_system_prompt, words, example_groupings)
115
  sampler = samplers.greedy()
116
  generator = generate.json(model, Groups, sampler)
117
 
118
- print("Grouping words:", words)
119
- generations = generator(
120
- prompt,
121
- max_tokens=500
122
- )
123
- print("Got groupings: ", generations)
124
  return generations.groups
125
 
 
 
 
 
 
 
126
 
127
- def generate_clues(group):
 
 
128
  @outlines.prompt
129
  def chat_clue_template(system, query, history=[]):
130
  '''<s><|system|>
@@ -140,23 +155,29 @@ def generate_clues(group):
140
  <|assistant|>
141
  '''
142
 
143
- clue_system_prompt = ("You are a codenames game companion. Your task is to give a single word clue related to "
144
- "a given group of words. You will only respond with a single word clue. Compound words are "
145
  "allowed. Do not include the word 'Clue'. Do not provide explanations or notes.")
146
 
147
  prompt = chat_clue_template(clue_system_prompt, group, example_clues)
148
- # sampler = samplers.greedy()
149
  sampler = samplers.multinomial(2, top_k=10)
150
  generator = generate.json(model, Clue, sampler)
151
  generations = generator(prompt, max_tokens=100)
152
- print("Got clues: ", generations)
153
  return generations[0]
154
 
155
-
156
- def jpeg_with_target_size(im, target):
157
- """Return the image as JPEG with the given name at best quality that makes less than "target" bytes
158
 
159
  https://stackoverflow.com/a/52281257
 
 
 
 
 
 
 
160
  """
161
  # Min and Max quality
162
  qmin, qmax = 25, 96
@@ -167,65 +188,90 @@ def jpeg_with_target_size(im, target):
167
 
168
  # Encode into memory and get size
169
  buffer = io.BytesIO()
170
- im.save(buffer, format="JPEG", quality=m)
171
  s = buffer.getbuffer().nbytes
172
 
173
- if s <= target:
174
  qacc = m
175
  qmin = m + 1
176
- elif s > target:
177
  qmax = m - 1
178
 
179
  # Write to disk at the defined quality
180
  if qacc > -1:
181
  image_byte_array = io.BytesIO()
182
- print("Acceptable quality", im, im.format, f"{im.size}x{im.mode}")
183
- im.save(image_byte_array, format='JPEG', quality=qacc)
184
  return image_byte_array.getvalue()
185
 
 
 
 
186
 
187
- def process_image(img):
188
- # Resize the image
189
- max_size = (1024, 1024)
190
- img.thumbnail(max_size)
191
 
192
- image_byte_array = jpeg_with_target_size(img, 180_000)
 
 
 
 
193
  image_b64 = base64.b64encode(image_byte_array).decode()
194
 
195
- invoke_url = "https://ai.api.nvidia.com/v1/vlm/microsoft/phi-3-vision-128k-instruct"
196
- stream = False
197
-
198
- if os.environ.get("NVIDIA_API_KEY", "").startswith("nvapi-"):
199
- print("Valid NVIDIA_API_KEY already in the environment. Delete to reset")
200
-
201
  headers = {
202
  "Authorization": f"Bearer {os.environ.get('NVIDIA_API_KEY', '')}",
203
- "Accept": "text/event-stream" if stream else "application/json"
204
  }
205
 
206
  payload = {
207
  "messages": [
208
  {
209
  "role": "user",
210
- "content": f'Identify the words in this game of Codenames. Provide only a list of words. Provide the '
211
- f'words in capital letters only. <img src="data:image/png;base64,{image_b64}" />'
212
  }
213
  ],
214
  "max_tokens": 512,
215
  "temperature": 0.1,
216
  "top_p": 0.70,
217
- "stream": stream
218
  }
219
 
220
- response = requests.post(invoke_url, headers=headers, json=payload)
221
  if response.ok:
222
  print(response.json())
223
- # Define the pattern to match uppercase words separated by commas
224
  pattern = r'[A-Z]+(?:\s+[A-Z]+)?'
225
  words = re.findall(pattern, response.json()['choices'][0]['message']['content'])
226
-
227
  return gr.update(choices=words, value=words)
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  if __name__ == '__main__':
231
  with gr.Blocks() as demo:
@@ -235,10 +281,7 @@ if __name__ == '__main__':
235
  with gr.Row():
236
  game_image = gr.Image(type="pil")
237
  word_list_input = gr.Dropdown(label="Enter list of words (comma separated)",
238
- choices='WEREWOLF, CHAIN, MOSQUITO, CRAFT, RANCH, LIP, VALENTINE, CLOUD, '
239
- 'BEARD, BUNK, SECOND, SADDLE, BUCKET, JAIL, ANT, POCKET, LACE, '
240
- 'BREAK, CUCKOO, FLAT, NIL, TIN, CHERRY, CHRISTMAS, MOSES, '
241
- 'TEAM'.split(', '),
242
  multiselect=True,
243
  interactive=True)
244
 
@@ -277,7 +320,7 @@ if __name__ == '__main__':
277
 
278
  def generate_clues_callback(group):
279
  print("Generating clues: ", group)
280
- g = generate_clues(group)
281
  return gr.update(value=g.word, info=g.explanation)
282
 
283
 
 
5
  import random
6
  import json
7
  import re
8
+ from typing import List, Tuple
9
+ import PIL
10
 
11
  import gradio as gr
12
  import outlines
 
14
  from outlines import models, generate, samplers
15
  from pydantic import BaseModel
16
 
17
+ # Constants
18
+ MAX_IMAGE_SIZE = (1024, 1024)
19
+ TARGET_IMAGE_SIZE = 180_000
20
+ NVIDIA_API_URL = "https://ai.api.nvidia.com/v1/vlm/microsoft/phi-3-vision-128k-instruct"
21
+ MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  class Clue(BaseModel):
24
  word: str
25
  explanation: str
26
 
 
27
  class Group(BaseModel):
28
  words: List[str]
29
  clue: str
30
  explanation: str
31
 
 
32
  class Groups(BaseModel):
33
  groups: List[Group]
34
 
 
35
  example_clues = [
36
  (['ARROW', 'TIE', 'HONOR'], 'BOW', 'such as a bow and arrow, a bow tie, or a bow as a sign of honor'),
37
  (['DOG', 'TREE'], 'BARK', 'such as the sound a dog makes, or a tree is made of bark'),
 
62
  'such as the Arctic being home to seals, or shutting a seal on an envelope, or a stamp being a type of seal'),
63
  ]
64
 
65
+ def create_random_word_groups(clues: List[Tuple[List[str], str, str]], target_groups: int = 10) -> List[Tuple[List[str], List[int]]]:
66
+ """
67
+ Creates approximately 'target_groups' random groups of words from the given clues.
68
 
69
+ Args:
70
+ clues: A list of clues, where each clue is a tuple (words, answer, explanation).
71
+ target_groups: The desired number of groups to create.
72
+
73
+ Returns:
74
+ A list of tuples, each containing a list of merged words and their corresponding indices.
75
+ """
76
+ groups = []
77
+ while len(groups) < target_groups:
78
+ num_rows = random.choice([3, 4])
79
+ selected_indices = random.sample(range(len(clues)), num_rows)
80
+ merged_words = [word for row in [clues[i][0] for i in selected_indices] for word in row]
81
+ if len(merged_words) in [8, 9]:
82
+ groups.append((merged_words, selected_indices))
83
+ return groups
84
+
85
+ def group_words(word_list: List[str]) -> List[Group]:
86
+ """
87
+ Groups the given words into 3 to 4 thematic groups.
88
+
89
+ Args:
90
+ word_list: A list of words to be grouped.
91
+
92
+ Returns:
93
+ A list of Group objects representing the grouped words.
94
+ """
95
  @outlines.prompt
96
  def chat_group_template(system_prompt, query, history=[]):
97
  '''<s><|system|>
 
108
  '''
109
 
110
  grouping_system_prompt = ("You are an assistant for the game Codenames. Your task is to help players by grouping a "
111
+ "given set of words into 3 to 4 groups. Each group should consist of words that "
112
+ "share a common theme or other word connections such as homonyms, hypernyms, or synonyms.")
113
+
114
  example_groupings = []
115
+ merges = create_random_word_groups(example_clues, 5)
116
  for merged_words, indices in merges:
117
  groups = [{
118
  "secrets": example_clues[i][0],
 
121
  } for i in indices]
122
  example_groupings.append((merged_words, json.dumps(groups, separators=(',', ':'))))
123
 
124
+ prompt = chat_group_template(grouping_system_prompt, word_list, example_groupings)
125
  sampler = samplers.greedy()
126
  generator = generate.json(model, Groups, sampler)
127
 
128
+ print(f"Grouping words: {word_list}")
129
+ generations = generator(prompt, max_tokens=500)
130
+ print(f"Generated groupings: {generations}")
 
 
 
131
  return generations.groups
132
 
133
+ def generate_clue(group: List[str]) -> Clue:
134
+ """
135
+ Generates a single-word clue for the given group of words.
136
+
137
+ Args:
138
+ group: A list of words to generate a clue for.
139
 
140
+ Returns:
141
+ A Clue object containing the generated word and its explanation.
142
+ """
143
  @outlines.prompt
144
  def chat_clue_template(system, query, history=[]):
145
  '''<s><|system|>
 
155
  <|assistant|>
156
  '''
157
 
158
+ clue_system_prompt = ("You are a Codenames game companion. Your task is to give a single word clue related to "
159
+ "a given group of words. Respond with a single word clue only. Compound words are "
160
  "allowed. Do not include the word 'Clue'. Do not provide explanations or notes.")
161
 
162
  prompt = chat_clue_template(clue_system_prompt, group, example_clues)
 
163
  sampler = samplers.multinomial(2, top_k=10)
164
  generator = generate.json(model, Clue, sampler)
165
  generations = generator(prompt, max_tokens=100)
166
+ print(f"Generated clues: {generations}")
167
  return generations[0]
168
 
169
+ def compress_image_to_jpeg(image: 'PIL.Image', target_size: int) -> bytes:
170
+ """
171
+ Compresses the image to JPEG format with the best quality that fits within the target size.
172
 
173
  https://stackoverflow.com/a/52281257
174
+
175
+ Args:
176
+ image: The PIL Image object to compress.
177
+ target_size: The target file size in bytes.
178
+
179
+ Returns:
180
+ The compressed image as bytes.
181
  """
182
  # Min and Max quality
183
  qmin, qmax = 25, 96
 
188
 
189
  # Encode into memory and get size
190
  buffer = io.BytesIO()
191
+ image.save(buffer, format="JPEG", quality=m)
192
  s = buffer.getbuffer().nbytes
193
 
194
+ if s <= target_size:
195
  qacc = m
196
  qmin = m + 1
197
+ elif s > target_size:
198
  qmax = m - 1
199
 
200
  # Write to disk at the defined quality
201
  if qacc > -1:
202
  image_byte_array = io.BytesIO()
203
+ print("Acceptable quality", image, image.format, f"{image.size}x{image.mode}")
204
+ image.save(image_byte_array, format='JPEG', quality=qacc)
205
  return image_byte_array.getvalue()
206
 
207
+ def process_image(img: 'PIL.Image') -> gr.update:
208
+ """
209
+ Processes the uploaded image to detect words for the Codenames game.
210
 
211
+ Args:
212
+ img: The uploaded PIL Image object.
 
 
213
 
214
+ Returns:
215
+ A gradio update object with the detected words.
216
+ """
217
+ img.thumbnail(MAX_IMAGE_SIZE)
218
+ image_byte_array = compress_image_to_jpeg(img, TARGET_IMAGE_SIZE)
219
  image_b64 = base64.b64encode(image_byte_array).decode()
220
 
 
 
 
 
 
 
221
  headers = {
222
  "Authorization": f"Bearer {os.environ.get('NVIDIA_API_KEY', '')}",
223
+ "Accept": "application/json"
224
  }
225
 
226
  payload = {
227
  "messages": [
228
  {
229
  "role": "user",
230
+ "content": f'Identify the words in this game of Codenames. Provide only a list of words in capital letters. <img src="data:image/png;base64,{image_b64}" />'
 
231
  }
232
  ],
233
  "max_tokens": 512,
234
  "temperature": 0.1,
235
  "top_p": 0.70,
236
+ "stream": False
237
  }
238
 
239
+ response = requests.post(NVIDIA_API_URL, headers=headers, json=payload)
240
  if response.ok:
241
  print(response.json())
 
242
  pattern = r'[A-Z]+(?:\s+[A-Z]+)?'
243
  words = re.findall(pattern, response.json()['choices'][0]['message']['content'])
 
244
  return gr.update(choices=words, value=words)
245
 
246
+ def pad_or_truncate_groups(groups: List[Group], target_length: int = 4) -> List[Group]:
247
+ """
248
+ Ensures the list of groups has exactly target_length elements, padding with empty Groups if necessary.
249
+
250
+ Args:
251
+ groups: The list of Group objects to pad or truncate.
252
+ target_length: The desired length of the list.
253
+
254
+ Returns:
255
+ A list of Group objects with the specified length.
256
+ """
257
+ truncated_groups = groups[:target_length]
258
+ return truncated_groups + [Group(words=[], clue='', explanation='') for _ in range(target_length - len(truncated_groups))]
259
+
260
+ def group_words_callback(words: List[str]) -> List[gr.update]:
261
+ """
262
+ Callback function to group the selected words.
263
+
264
+ Args:
265
+ words: A list of words to group.
266
+
267
+ Returns:
268
+ A list of gradio update objects for each group input.
269
+ """
270
+ groups = group_words(words)
271
+ groups = pad_or_truncate_groups(groups, 4)
272
+ print(f"Generated groups: {groups}")
273
+ return [gr.update(value=group.words, choices=group.words, info=group.explanation) for group in groups]
274
+
275
 
276
  if __name__ == '__main__':
277
  with gr.Blocks() as demo:
 
281
  with gr.Row():
282
  game_image = gr.Image(type="pil")
283
  word_list_input = gr.Dropdown(label="Enter list of words (comma separated)",
284
+ choices=[],
 
 
 
285
  multiselect=True,
286
  interactive=True)
287
 
 
320
 
321
  def generate_clues_callback(group):
322
  print("Generating clues: ", group)
323
+ g = generate_clue(group)
324
  return gr.update(value=g.word, info=g.explanation)
325
 
326