Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -186,9 +186,85 @@ def click_on_display(language_instruction, grounding_texts, sketch_pad,
|
|
186 |
|
187 |
return gen_images + [state]
|
188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
|
190 |
loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
|
191 |
state):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
if 'boxes' not in state:
|
193 |
state['boxes'] = []
|
194 |
boxes = state['boxes']
|
|
|
186 |
|
187 |
return gen_images + [state]
|
188 |
|
189 |
+
def Pharse2idx(prompt, phrases):
    """Map each grounding phrase to its 1-based token positions in *prompt*.

    *phrases* is a single string of phrases separated by '; ', where '_'
    stands in for a space inside a multi-word phrase (e.g. "red_car; dog").
    Returns one list of positions per phrase, or None (after printing the
    prompt and phrases for debugging) if any phrase is absent.
    """

    def _window_matches(window, target):
        # Exact token match, or each prompt token may be a simple plural
        # ('s' / 'es' suffix) of the corresponding phrase token.
        if window == target:
            return True
        for got, want in zip(window, target):
            if got != want and got != want + 's' and got != want + 'es':
                return False
        return True

    phrases = [p.replace('_', ' ') for p in phrases.split('; ')]
    print(phrases)
    object_positions = []

    # Detach punctuation so it tokenizes as its own word and never glues
    # onto a phrase token.
    for punc in [',', '.', ';', ':', '?', '!']:
        prompt = prompt.replace(punc, ' ' + punc)
    words = prompt.split()

    for phrase in phrases:
        target = phrase.split()
        span = len(target)
        positions = []

        # Slide a window of the phrase's length across the prompt tokens.
        for start in range(len(words) - span + 1):
            if _window_matches(words[start:start + span], target):
                positions += list(range(start + 1, start + span + 1))
        if positions == []:
            print(prompt)
            print(phrases)
            return None
        object_positions.append(positions)
    print(object_positions)
    return object_positions
|
218 |
+
|
219 |
+
|
220 |
def generate(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
             loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
             state):
    """Run box-grounded image generation for the UI callback.

    language_instruction is the text prompt; grounding_texts is a
    ';'-separated string of phrases, one per box drawn on the sketch pad
    (stored in state['boxes']).  Returns the gr.Image updates for up to
    4 gallery slots followed by the (possibly mutated) state.

    Raises ValueError when fewer boxes were drawn than phrases given.
    """
    if 'boxes' not in state:
        state['boxes'] = []
    boxes = state['boxes']

    print('grounding texts:', grounding_texts)
    # Keep the raw phrase string: Pharse2idx does its own '; ' splitting.
    phrases = grounding_texts

    grounding_texts = [x.strip() for x in grounding_texts.split(';')]

    if len(boxes) != len(grounding_texts):
        if len(boxes) < len(grounding_texts):
            raise ValueError("""The number of boxes should be equal to the number of grounding objects.
Number of boxes drawn: {}, number of grounding tokens: {}.
Please draw boxes accordingly on the sketch pad.""".format(len(boxes), len(grounding_texts)))
        # More boxes than phrases: pad with empty grounding tokens.
        grounding_texts = grounding_texts + [""] * (len(boxes) - len(grounding_texts))

    # Normalize pixel coordinates to [0, 1]; sketch pad appears to be
    # 512x512 — TODO confirm against the UI component.
    boxes = (np.asarray(boxes) / 512).tolist()
    boxes = [[box] for box in boxes]

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # BUGFIX: the previous version first computed object_positions with
    # language_instruction_list.index(word) — dead code whose result was
    # immediately overwritten below, but which raised ValueError whenever a
    # phrase word was absent from the prompt.  It also built an unused
    # grounding_instruction JSON string.  Both removed.
    object_positions = Pharse2idx(language_instruction, phrases)

    gen_images = inference(device, unet, vae, tokenizer, text_encoder, language_instruction, boxes,
                           object_positions, batch_size, loss_scale, loss_threshold, max_iter,
                           max_step, rand_seed, guidance_scale)

    # Pad/hide gallery slots so exactly 4 gr.Image outputs are returned.
    blank_samples = batch_size % 2 if batch_size > 1 else 0
    gen_images = [gr.Image.update(value=x, visible=True) for x in gen_images] \
        + [gr.Image.update(value=None, visible=True) for _ in range(blank_samples)] \
        + [gr.Image.update(value=None, visible=False) for _ in range(4 - batch_size - blank_samples)]

    return gen_images + [state]
|
264 |
+
|
265 |
+
def generate_legacy(unet, vae, tokenizer, text_encoder, language_instruction, grounding_texts, sketch_pad,
|
266 |
+
loss_threshold, guidance_scale, batch_size, rand_seed, max_step, loss_scale, max_iter,
|
267 |
+
state):
|
268 |
if 'boxes' not in state:
|
269 |
state['boxes'] = []
|
270 |
boxes = state['boxes']
|