adymaharana commited on
Commit
908bed5
1 Parent(s): 1cac669

fp16 version

Browse files
app.py CHANGED
@@ -6,9 +6,14 @@ from dalle.models import StoryDalle
6
  import argparse
7
  from PIL import Image
8
  from torchvision.utils import save_image
 
9
  import tensorflow_hub as hub
10
  import gdown
 
 
11
 
 
 
12
 
13
  source_frame_paths = {
14
  'Pororo': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_2/Pororo_ENGLISH1_2_ep6/12.png',
@@ -23,6 +28,51 @@ source_frame_paths = {
23
  }
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def inverse_normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
27
  mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
28
  std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
@@ -66,9 +116,10 @@ def save_story_results(images, video_len=4, n_candidates=1, mask=None):
66
 
67
 
68
  def main(args):
 
69
  #device = 'cuda:0'
70
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
71
- #device = torch.device('cpu')
72
 
73
  model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'
74
 
@@ -81,6 +132,9 @@ def main(args):
81
  #assert os.path.exists("./ckpt/25.pth")
82
  gdown.download(png_url, quiet=True, use_cookies=False, output="demo_pororo_good.png")
83
 
 
 
 
84
  if args.debug:
85
  model = None
86
  embed = None
@@ -88,13 +142,20 @@ def main(args):
88
  model, config = StoryDalle.from_pretrained(args)
89
  model.tokenizer.add_tokens(['pororo', 'loopy', 'eddy', 'harry', 'poby', 'tongtong', 'crong', 'rody', 'petty'])
90
  model.eval()
91
- model.to(device=device)
 
 
 
 
 
 
 
 
 
 
 
92
  embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5")
93
 
94
- if model.config.story.condition:
95
- for i in range(len(model.cross_attention_layers)):
96
- model.cross_attention_layers[i].to(device)
97
- print("Cross-attention layers are in cuda:", next(model.cross_attention_layers[0].parameters()).is_cuda)
98
 
99
  valid_transform = transforms.Compose(
100
  [transforms.Resize(config.dataset.image_resolution),
@@ -103,6 +164,8 @@ def main(args):
103
  transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
104
  )
105
 
 
 
106
  #torch.save(model, './ckpt/checkpoint.pt')
107
  #sys.exit()
108
 
@@ -110,32 +173,62 @@ def main(args):
110
  supercondition=False):
111
 
112
  if not args.debug:
113
- captions = [caption_1, caption_2, caption_3, caption_4]
 
 
 
 
 
 
 
 
114
  mask = [1 if caption != '' else 0 for caption in captions]
 
 
 
 
115
  print(captions, mask, source, n_candidates)
 
 
 
116
  for i, caption in enumerate(captions):
117
  if caption == "":
118
- captions[i] = "Pororo is reading a book."
 
119
  tokens = [model.tokenizer.encode(caption) for caption in captions]
120
  texts = torch.stack([torch.LongTensor(token.ids) for token in tokens]).unsqueeze(0)
121
  sent_embeds = torch.tensor(embed(captions).numpy())
122
- # sent_embeds = torch.tensor(description_vecs[source_frame_paths[source].
123
- # replace('/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/', '')[:-4]][0]).unsqueeze(0).repeat(4, 1)
124
-
125
  src_image = valid_transform(Image.open('./demo/%s.png' % source).convert('RGB'))
126
 
127
  stories = []
128
  with torch.no_grad():
129
  for i in range(texts.shape[0]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  pixels = model.sampling_batch(texts[i].to(device), src_image.unsqueeze(0).to(device),
131
- sent_embeds.unsqueeze(0).to(device), top_k=top_k, top_p=top_p,
132
- prompt=None, n_candidates=n_candidates).cpu()
133
  stories.append(pixels)
134
-
135
  img = save_story_results(stories, video_len=4, n_candidates=n_candidates, mask=mask)
136
- save_image(img, "gradio_demo_pororo.png", normalize=True)
 
 
 
137
 
138
- return "gradio_demo_pororo.png"
139
 
140
  with gr.Blocks(css='#output {width:750px; height:750px; float:left;}') as demo:
141
  gr.Markdown('''
@@ -170,7 +263,7 @@ def main(args):
170
  Here are some examples of generated visual stories for the above-mentioned settings.
171
 
172
  <p align="center">
173
- <img src="file/demo_pororo_good.png" width="1000">
174
  </p>
175
 
176
  Due to the small training dataset size for story visualization, the model has poor generalization to some unseen settings. The model struggles to generate coherent images in the following scenarios.
@@ -236,10 +329,11 @@ def main(args):
236
  \[4\] Sharma, Piyush, et al. "Conceptual captions: A cleaned, hypernymed, image alt-text dataset for automatic image captioning." Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2018.
237
  ''')
238
 
239
- demo.launch(share=True)
240
 
241
 
242
  if __name__ == "__main__":
 
243
  args_list = ['--model_name_or_path', './ckpt/25.pth',
244
  '--prefix_model_name_or_path', './1.3B/',
245
  '--dataset_name', 'pororo',
@@ -351,6 +445,7 @@ if __name__ == "__main__":
351
  )
352
 
353
  parser.add_argument("--debug", action="store_true", help="Whether to debug the demo.")
 
354
 
355
  args = parser.parse_args(args_list)
356
 
 
6
  import argparse
7
  from PIL import Image
8
  from torchvision.utils import save_image
9
+ import tensorflow as tf
10
  import tensorflow_hub as hub
11
  import gdown
12
+ from allennlp.predictors.predictor import Predictor
13
+ import random
14
 
15
+ torch.set_grad_enabled(False)
16
+ tf.config.set_visible_devices([], 'GPU') # setting Tensorflow's GPU visibility to None to constraing embedding model to CPU
17
 
18
  source_frame_paths = {
19
  'Pororo': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_2/Pororo_ENGLISH1_2_ep6/12.png',
 
28
  }
29
 
30
 
31
+ def get_span_words(span, document):
32
+ return ' '.join(document[span[0]:span[1]+1])
33
+
34
+
35
+ def print_clusters(prediction):
36
+ document, clusters = prediction['document'], prediction['clusters']
37
+ for cluster in clusters:
38
+ print(get_span_words(cluster[0], document) + ': ', end='')
39
+ print(f"[{'; '.join([get_span_words(span, document) for span in cluster])}]")
40
+
41
+
42
+ def resolve_coref(captions, captions_mask, coref_predictor):
43
+ sent_counts = []
44
+ doc = ''
45
+ for cap, mask in zip(captions, captions_mask):
46
+ if mask == 0:
47
+ sent_counts.append(0)
48
+ else:
49
+ print(cap)
50
+ count = len([c.strip() for c in cap.split('.') if c.strip()])
51
+ sent_counts.append(count)
52
+ doc += cap + ' '
53
+
54
+ # print(doc)
55
+
56
+ doc = doc.strip()
57
+ resolved_doc = coref_predictor.coref_resolved(doc)
58
+ # print(resolved_doc)
59
+ # print(sent_counts)
60
+
61
+ sents = resolved_doc.split('. ')
62
+ resolved_captions = []
63
+ for i, (count, mask) in enumerate(zip(sent_counts, captions_mask)):
64
+ if mask == 0:
65
+ resolved_captions.append('')
66
+ else:
67
+ new_cap = '. '.join(sents[sum(sent_counts[:i]):sum(sent_counts[:i]) + count])
68
+ new_cap = new_cap.strip()
69
+ if new_cap[-1] not in ['!', '?', '.']:
70
+ new_cap += '.'
71
+ resolved_captions.append(new_cap)
72
+
73
+ return resolved_captions
74
+
75
+
76
  def inverse_normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
77
  mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
78
  std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
 
116
 
117
 
118
  def main(args):
119
+
120
  #device = 'cuda:0'
121
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
122
+ # device = torch.device('cpu')
123
 
124
  model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'
125
 
 
132
  #assert os.path.exists("./ckpt/25.pth")
133
  gdown.download(png_url, quiet=True, use_cookies=False, output="demo_pororo_good.png")
134
 
135
+ coref_model_url = 'https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz'
136
+ coref_predictor = Predictor.from_path(coref_model_url)
137
+
138
  if args.debug:
139
  model = None
140
  embed = None
 
142
  model, config = StoryDalle.from_pretrained(args)
143
  model.tokenizer.add_tokens(['pororo', 'loopy', 'eddy', 'harry', 'poby', 'tongtong', 'crong', 'rody', 'petty'])
144
  model.eval()
145
+ # split_model into CPU and GPU
146
+ if args.split_memory:
147
+ model.stage2.to(device=device)
148
+ model.story_linear.to(device=device)
149
+ model.story_block.to(device=device)
150
+ else:
151
+ model.to(device=device)
152
+ if model.config.story.condition:
153
+ for i in range(len(model.cross_attention_layers)):
154
+ model.cross_attention_layers[i].to(device)
155
+ print("Cross-attention layers are in cuda:", next(model.cross_attention_layers[0].parameters()).is_cuda)
156
+
157
  embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5")
158
 
 
 
 
 
159
 
160
  valid_transform = transforms.Compose(
161
  [transforms.Resize(config.dataset.image_resolution),
 
164
  transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
165
  )
166
 
167
+ print("Model is in ", model.device)
168
+
169
  #torch.save(model, './ckpt/checkpoint.pt')
170
  #sys.exit()
171
 
 
173
  supercondition=False):
174
 
175
  if not args.debug:
176
+
177
+ suffix = random.randint(0, 1000)
178
+ img_file_path = "./demo/images/gradio_demo_pororo_%s.png" % suffix
179
+ txt_file_path = "./demo/texts/gradio_demo_pororo_%s.txt" % suffix
180
+
181
+ captions = [caption_1.strip(), caption_2.strip(), caption_3.strip(), caption_4.strip()]
182
+ for i in range(len(captions)):
183
+ if captions[i][-1] not in ['!', '?', '.']:
184
+ captions[i] = captions[i] + '.'
185
  mask = [1 if caption != '' else 0 for caption in captions]
186
+
187
+ with open(txt_file_path, 'w') as f:
188
+ f.write('\n'.join(captions))
189
+
190
  print(captions, mask, source, n_candidates)
191
+ captions = resolve_coref(captions, mask, coref_predictor)
192
+ print(captions)
193
+
194
  for i, caption in enumerate(captions):
195
  if caption == "":
196
+ captions[i] = "Pororo is reading a book." # filler for shorter captions
197
+
198
  tokens = [model.tokenizer.encode(caption) for caption in captions]
199
  texts = torch.stack([torch.LongTensor(token.ids) for token in tokens]).unsqueeze(0)
200
  sent_embeds = torch.tensor(embed(captions).numpy())
 
 
 
201
  src_image = valid_transform(Image.open('./demo/%s.png' % source).convert('RGB'))
202
 
203
  stories = []
204
  with torch.no_grad():
205
  for i in range(texts.shape[0]):
206
+ candidates = []
207
+ # for _ in range(n_candidates):
208
+ # if args.split_memory: # if splitting model into CPU/GPU, send src_image from CPU memory
209
+ # pixels = model.sampling_batch(texts[i].to(device), src_image.unsqueeze(0),
210
+ # sent_embeds.unsqueeze(0).to(device), top_k=top_k, top_p=top_p,
211
+ # prompt=None, n_candidates=1, device=device).cpu()
212
+ # else:
213
+ # pixels = model.sampling_batch(texts[i].to(device), src_image.unsqueeze(0).to(device),
214
+ # sent_embeds.unsqueeze(0).to(device), top_k=top_k, top_p=top_p,
215
+ # prompt=None, n_candidates=1).cpu()
216
+ # print(pixels.shape)
217
+ # candidates.append(pixels.squeeze())
218
+ # stories.append(torch.stack(candidates))
219
+ #with torch.cuda.amp.autocast():
220
+
221
  pixels = model.sampling_batch(texts[i].to(device), src_image.unsqueeze(0).to(device),
222
+ sent_embeds.unsqueeze(0).to(device), top_k=top_k, top_p=top_p,
223
+ prompt=None, n_candidates=n_candidates).cpu()
224
  stories.append(pixels)
 
225
  img = save_story_results(stories, video_len=4, n_candidates=n_candidates, mask=mask)
226
+ save_image(img, img_file_path, normalize=True)
227
+
228
+ else:
229
+ img_file_path = "gradio_demo_pororo.png"
230
 
231
+ return img_file_path
232
 
233
  with gr.Blocks(css='#output {width:750px; height:750px; float:left;}') as demo:
234
  gr.Markdown('''
 
263
  Here are some examples of generated visual stories for the above-mentioned settings.
264
 
265
  <p align="center">
266
+ <img src="file/demo_pororo_good_v1.png" width="1000">
267
  </p>
268
 
269
  Due to the small training dataset size for story visualization, the model has poor generalization to some unseen settings. The model struggles to generate coherent images in the following scenarios.
 
329
  \[4\] Sharma, Piyush, et al. "Conceptual captions: A cleaned, hypernymed, image alt-text dataset for automatic image captioning." Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2018.
330
  ''')
331
 
332
+ demo.launch(share=False)
333
 
334
 
335
  if __name__ == "__main__":
336
+
337
  args_list = ['--model_name_or_path', './ckpt/25.pth',
338
  '--prefix_model_name_or_path', './1.3B/',
339
  '--dataset_name', 'pororo',
 
445
  )
446
 
447
  parser.add_argument("--debug", action="store_true", help="Whether to debug the demo.")
448
+ parser.add_argument("--split_memory", action="store_true", help="Whether to split the model into GPU & CPU in the demo.")
449
 
450
  args = parser.parse_args(args_list)
451
 
dalle/models/__init__.py CHANGED
@@ -1094,7 +1094,7 @@ class PromptConditionalDalle(Dalle):
1094
  prompt = self.get_prompt(bsz=5, eval=True)
1095
 
1096
  images = []
1097
- for t in texts:
1098
  pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
1099
  pixels = np.transpose(pixels, (0, 2, 3, 1))
1100
  images.append(pixels)
@@ -1211,7 +1211,6 @@ class StoryDalle(Dalle):
1211
  lowercase=True,
1212
  dropout=None)
1213
 
1214
-
1215
  return model, config_update
1216
 
1217
 
@@ -1224,6 +1223,7 @@ class StoryDalle(Dalle):
1224
  resid_pdrop=hparams.resid_pdrop,
1225
  attn_pdrop=hparams.attn_pdrop) for i in cross_attention_layers]
1226
 
 
1227
  def get_prompt_p5(self, bsz=None, eval=False):
1228
  input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
1229
  temp_control = self.wte(input_tokens)
@@ -1232,6 +1232,7 @@ class StoryDalle(Dalle):
1232
  past_key_values = self.dropout(past_key_values)
1233
  return past_key_values
1234
 
 
1235
  def forward(self,
1236
  images: torch.FloatTensor,
1237
  src_images: Optional[torch.FloatTensor],
@@ -1287,6 +1288,7 @@ class StoryDalle(Dalle):
1287
  # print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
1288
  return logits_img, logits_txt, codes
1289
 
 
1290
  @torch.no_grad()
1291
  def sampling(self,
1292
  tokens: torch.LongTensor,
@@ -1327,6 +1329,7 @@ class StoryDalle(Dalle):
1327
 
1328
  #with autocast(enabled=False):
1329
  src_codes = self.stage1.get_codes(source).detach()
 
1330
  src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len, dim=0)
1331
  print(tokens.shape, src_codes.shape, prompt.shape)
1332
  if self.config.story.condition:
@@ -1355,6 +1358,7 @@ class StoryDalle(Dalle):
1355
  pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
1356
  return pixels
1357
 
 
1358
  @torch.no_grad()
1359
  def sampling_batch(self,
1360
  tokens: torch.LongTensor,
@@ -1363,10 +1367,8 @@ class StoryDalle(Dalle):
1363
  top_k: int = 256,
1364
  top_p: Optional[float] = None,
1365
  softmax_temperature: float = 1.0,
1366
- num_candidates: int = 96,
1367
  device: str = 'cuda:0',
1368
  use_fp16: bool = True,
1369
- labels=None,
1370
  prompt=None, n_candidates=1) -> torch.FloatTensor:
1371
 
1372
  self.stage1.eval()
@@ -1396,37 +1398,40 @@ class StoryDalle(Dalle):
1396
 
1397
  #with autocast(enabled=False):
1398
  src_codes = self.stage1.get_codes(source).detach()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1399
 
1400
- # repeat inputs to adjust to n_candidates and story length
1401
- src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
1402
- prompt = prompt.repeat(n_candidates, 1, 1)
1403
- pos_enc_prompt = pos_enc_prompt.repeat(n_candidates, 1)
1404
- tokens = tokens.repeat(n_candidates, 1)
1405
- print(tokens.shape, src_codes.shape, prompt.shape, pos_enc_prompt.shape)
1406
- if self.config.story.condition:
1407
- codes = sampling_conditional(self.stage2,
1408
- self.cross_attention_idxs,
1409
- self.cross_attention_layers,
1410
- tokens,
1411
- src_codes,
1412
- top_k=top_k,
1413
- top_p=top_p,
1414
- softmax_temperature=softmax_temperature,
1415
- use_fp16=use_fp16,
1416
- prompt=prompt,
1417
- pos_prompt=pos_enc_prompt)
1418
- else:
1419
- codes = sampling(self.stage2,
1420
- tokens,
1421
- top_k=top_k,
1422
- top_p=top_p,
1423
- softmax_temperature=softmax_temperature,
1424
- use_fp16=use_fp16,
1425
- prompt=prompt,
1426
- pos_prompt=pos_enc_prompt)
1427
-
1428
- codes = codes.view(self.config.story.story_len * n_candidates, 16, 16) # [B, 16, 16]
1429
- print(codes.shape)
1430
  pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 3, 256, 256]
1431
  print(pixels.shape)
1432
  return pixels.view(n_candidates, self.config.story.story_len, pixels.shape[-3], pixels.shape[-2], pixels.shape[-1])
@@ -1444,11 +1449,10 @@ class StoryDalle(Dalle):
1444
  pred = pred.view(bs, 16, 16) # [B, 16, 16]
1445
  pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy() # [B, 256, 256]
1446
  pixels = np.transpose(pixels, (0, 2, 3, 1))
1447
-
1448
  prompt = self.get_prompt(bsz=5, eval=True)
1449
 
1450
  images = []
1451
- for t in texts:
1452
  pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
1453
  pixels = np.transpose(pixels, (0, 2, 3, 1))
1454
  images.append(pixels)
 
1094
  prompt = self.get_prompt(bsz=5, eval=True)
1095
 
1096
  images = []
1097
+ for i, t in enumerate(texts):
1098
  pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
1099
  pixels = np.transpose(pixels, (0, 2, 3, 1))
1100
  images.append(pixels)
 
1211
  lowercase=True,
1212
  dropout=None)
1213
 
 
1214
  return model, config_update
1215
 
1216
 
 
1223
  resid_pdrop=hparams.resid_pdrop,
1224
  attn_pdrop=hparams.attn_pdrop) for i in cross_attention_layers]
1225
 
1226
+
1227
  def get_prompt_p5(self, bsz=None, eval=False):
1228
  input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
1229
  temp_control = self.wte(input_tokens)
 
1232
  past_key_values = self.dropout(past_key_values)
1233
  return past_key_values
1234
 
1235
+
1236
  def forward(self,
1237
  images: torch.FloatTensor,
1238
  src_images: Optional[torch.FloatTensor],
 
1288
  # print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
1289
  return logits_img, logits_txt, codes
1290
 
1291
+
1292
  @torch.no_grad()
1293
  def sampling(self,
1294
  tokens: torch.LongTensor,
 
1329
 
1330
  #with autocast(enabled=False):
1331
  src_codes = self.stage1.get_codes(source).detach()
1332
+
1333
  src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len, dim=0)
1334
  print(tokens.shape, src_codes.shape, prompt.shape)
1335
  if self.config.story.condition:
 
1358
  pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
1359
  return pixels
1360
 
1361
+
1362
  @torch.no_grad()
1363
  def sampling_batch(self,
1364
  tokens: torch.LongTensor,
 
1367
  top_k: int = 256,
1368
  top_p: Optional[float] = None,
1369
  softmax_temperature: float = 1.0,
 
1370
  device: str = 'cuda:0',
1371
  use_fp16: bool = True,
 
1372
  prompt=None, n_candidates=1) -> torch.FloatTensor:
1373
 
1374
  self.stage1.eval()
 
1398
 
1399
  #with autocast(enabled=False):
1400
  src_codes = self.stage1.get_codes(source).detach()
1401
+ # src_codes = src_codes.to(device=device) #ensure that src_codes is moved to GPU in case VQGAN was kept in CPU
1402
+
1403
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
1404
+ # repeat inputs to adjust to n_candidates and story length
1405
+ src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
1406
+ prompt = prompt.repeat(n_candidates, 1, 1)
1407
+ pos_enc_prompt = pos_enc_prompt.repeat(n_candidates, 1)
1408
+ tokens = tokens.repeat(n_candidates, 1)
1409
+ print(tokens.shape, src_codes.shape, prompt.shape, pos_enc_prompt.shape)
1410
+ if self.config.story.condition:
1411
+ codes = sampling_conditional(self.stage2,
1412
+ self.cross_attention_idxs,
1413
+ self.cross_attention_layers,
1414
+ tokens,
1415
+ src_codes,
1416
+ top_k=top_k,
1417
+ top_p=top_p,
1418
+ softmax_temperature=softmax_temperature,
1419
+ use_fp16=use_fp16,
1420
+ prompt=prompt,
1421
+ pos_prompt=pos_enc_prompt)
1422
+ else:
1423
+ codes = sampling(self.stage2,
1424
+ tokens,
1425
+ top_k=top_k,
1426
+ top_p=top_p,
1427
+ softmax_temperature=softmax_temperature,
1428
+ use_fp16=use_fp16,
1429
+ prompt=prompt,
1430
+ pos_prompt=pos_enc_prompt)
1431
+
1432
+ codes = codes.view(self.config.story.story_len * n_candidates, 16, 16) # [B, 16, 16]
1433
+ print(codes.shape)
1434
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1435
  pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 3, 256, 256]
1436
  print(pixels.shape)
1437
  return pixels.view(n_candidates, self.config.story.story_len, pixels.shape[-3], pixels.shape[-2], pixels.shape[-1])
 
1449
  pred = pred.view(bs, 16, 16) # [B, 16, 16]
1450
  pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy() # [B, 256, 256]
1451
  pixels = np.transpose(pixels, (0, 2, 3, 1))
 
1452
  prompt = self.get_prompt(bsz=5, eval=True)
1453
 
1454
  images = []
1455
+ for i, t in enumerate(texts):
1456
  pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
1457
  pixels = np.transpose(pixels, (0, 2, 3, 1))
1458
  images.append(pixels)
dalle/models/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/dalle/models/__pycache__/__init__.cpython-38.pyc and b/dalle/models/__pycache__/__init__.cpython-38.pyc differ
 
dalle/models/stage2/__pycache__/layers.cpython-38.pyc CHANGED
Binary files a/dalle/models/stage2/__pycache__/layers.cpython-38.pyc and b/dalle/models/stage2/__pycache__/layers.cpython-38.pyc differ
 
dalle/models/stage2/layers.py CHANGED
@@ -182,8 +182,13 @@ class Block(nn.Module):
182
  def sample_with_context(self, x, context, context_mask, cross_attn_layer, layer_past=None):
183
  attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
184
  x = x + attn
185
- c_attn = cross_attn_layer(x, context, context_mask)
186
- x = x + c_attn
 
 
 
 
 
187
  x = x + self.mlp(self.ln2(x))
188
  return x, present
189
 
 
182
  def sample_with_context(self, x, context, context_mask, cross_attn_layer, layer_past=None):
183
  attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
184
  x = x + attn
185
+
186
+ c_attn = cross_attn_layer(x.to(device=context.device),
187
+ context,
188
+ context_mask.to(device=context.device))
189
+
190
+ x = x + c_attn.to(device=x.device)
191
+
192
  x = x + self.mlp(self.ln2(x))
193
  return x, present
194
 
gradio_demo_pororo.png ADDED

Git LFS Details

  • SHA256: 8f1e899b65857530477e5a37d333c5853b2c87122dbe4e5f70c4591d881ee66b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
requirements.txt CHANGED
@@ -10,3 +10,4 @@ pytorch-lightning
10
  einops
11
  tokenizers
12
  tensorflow
 
 
10
  einops
11
  tokenizers
12
  tensorflow
13
+ allennlp==2.10.0