Spaces: Build error

adymaharana committed
Commit • 908bed5
1 Parent(s): 1cac669

fp16 version

Browse files
- app.py +113 -18
- dalle/models/__init__.py +40 -36
- dalle/models/__pycache__/__init__.cpython-38.pyc +0 -0
- dalle/models/stage2/__pycache__/layers.cpython-38.pyc +0 -0
- dalle/models/stage2/layers.py +7 -2
- gradio_demo_pororo.png +3 -0
- requirements.txt +1 -0
app.py
CHANGED
@@ -6,9 +6,14 @@ from dalle.models import StoryDalle
 import argparse
 from PIL import Image
 from torchvision.utils import save_image
+import tensorflow as tf
 import tensorflow_hub as hub
 import gdown
+from allennlp.predictors.predictor import Predictor
+import random
 
+torch.set_grad_enabled(False)
+tf.config.set_visible_devices([], 'GPU') # setting Tensorflow's GPU visibility to None to constraing embedding model to CPU
 
 source_frame_paths = {
     'Pororo': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_2/Pororo_ENGLISH1_2_ep6/12.png',
@@ -23,6 +28,51 @@ source_frame_paths = {
 }
 
 
+def get_span_words(span, document):
+    return ' '.join(document[span[0]:span[1] + 1])
+
+
+def print_clusters(prediction):
+    document, clusters = prediction['document'], prediction['clusters']
+    for cluster in clusters:
+        print(get_span_words(cluster[0], document) + ': ', end='')
+        print(f"[{'; '.join([get_span_words(span, document) for span in cluster])}]")
+
+
+def resolve_coref(captions, captions_mask, coref_predictor):
+    sent_counts = []
+    doc = ''
+    for cap, mask in zip(captions, captions_mask):
+        if mask == 0:
+            sent_counts.append(0)
+        else:
+            print(cap)
+            count = len([c.strip() for c in cap.split('.') if c.strip()])
+            sent_counts.append(count)
+            doc += cap + ' '
+
+    # print(doc)
+
+    doc = doc.strip()
+    resolved_doc = coref_predictor.coref_resolved(doc)
+    # print(resolved_doc)
+    # print(sent_counts)
+
+    sents = resolved_doc.split('. ')
+    resolved_captions = []
+    for i, (count, mask) in enumerate(zip(sent_counts, captions_mask)):
+        if mask == 0:
+            resolved_captions.append('')
+        else:
+            new_cap = '. '.join(sents[sum(sent_counts[:i]):sum(sent_counts[:i]) + count])
+            new_cap = new_cap.strip()
+            if new_cap[-1] not in ['!', '?', '.']:
+                new_cap += '.'
+            resolved_captions.append(new_cap)
+
+    return resolved_captions
+
+
 def inverse_normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
     mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
     std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
@@ -66,9 +116,10 @@ def save_story_results(images, video_len=4, n_candidates=1, mask=None):
 
 
 def main(args):
+
     #device = 'cuda:0'
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    #device = torch.device('cpu')
+    # device = torch.device('cpu')
 
     model_url = 'https://drive.google.com/u/1/uc?id=1KAXVtE8lEE2Yc83VY7w6ycOOMkdWbmJo&export=sharing'
 
@@ -81,6 +132,9 @@ def main(args):
     #assert os.path.exists("./ckpt/25.pth")
     gdown.download(png_url, quiet=True, use_cookies=False, output="demo_pororo_good.png")
 
+    coref_model_url = 'https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz'
+    coref_predictor = Predictor.from_path(coref_model_url)
+
     if args.debug:
         model = None
        embed = None
@@ -88,13 +142,20 @@
         model, config = StoryDalle.from_pretrained(args)
         model.tokenizer.add_tokens(['pororo', 'loopy', 'eddy', 'harry', 'poby', 'tongtong', 'crong', 'rody', 'petty'])
         model.eval()
-
+        # split_model into CPU and GPU
+        if args.split_memory:
+            model.stage2.to(device=device)
+            model.story_linear.to(device=device)
+            model.story_block.to(device=device)
+        else:
+            model.to(device=device)
+        if model.config.story.condition:
+            for i in range(len(model.cross_attention_layers)):
+                model.cross_attention_layers[i].to(device)
+            print("Cross-attention layers are in cuda:", next(model.cross_attention_layers[0].parameters()).is_cuda)
+
         embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5")
 
-        if model.config.story.condition:
-            for i in range(len(model.cross_attention_layers)):
-                model.cross_attention_layers[i].to(device)
-            print("Cross-attention layers are in cuda:", next(model.cross_attention_layers[0].parameters()).is_cuda)
 
         valid_transform = transforms.Compose(
             [transforms.Resize(config.dataset.image_resolution),
@@ -103,6 +164,8 @@
              transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
         )
 
+        print("Model is in ", model.device)
+
     #torch.save(model, './ckpt/checkpoint.pt')
     #sys.exit()
 
@@ -110,32 +173,62 @@
                 supercondition=False):
 
         if not args.debug:
-
+
+            suffix = random.randint(0, 1000)
+            img_file_path = "./demo/images/gradio_demo_pororo_%s.png" % suffix
+            txt_file_path = "./demo/texts/gradio_demo_pororo_%s.txt" % suffix
+
+            captions = [caption_1.strip(), caption_2.strip(), caption_3.strip(), caption_4.strip()]
+            for i in range(len(captions)):
+                if captions[i][-1] not in ['!', '?', '.']:
+                    captions[i] = captions[i] + '.'
             mask = [1 if caption != '' else 0 for caption in captions]
+
+            with open(txt_file_path, 'w') as f:
+                f.write('\n'.join(captions))
+
             print(captions, mask, source, n_candidates)
+            captions = resolve_coref(captions, mask, coref_predictor)
+            print(captions)
+
             for i, caption in enumerate(captions):
                 if caption == "":
-                    captions[i] = "Pororo is reading a book."
+                    captions[i] = "Pororo is reading a book." # filler for shorter captions
+
             tokens = [model.tokenizer.encode(caption) for caption in captions]
             texts = torch.stack([torch.LongTensor(token.ids) for token in tokens]).unsqueeze(0)
             sent_embeds = torch.tensor(embed(captions).numpy())
-            # sent_embeds = torch.tensor(description_vecs[source_frame_paths[source].
-            # replace('/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/', '')[:-4]][0]).unsqueeze(0).repeat(4, 1)
-
             src_image = valid_transform(Image.open('./demo/%s.png' % source).convert('RGB'))
 
             stories = []
             with torch.no_grad():
                 for i in range(texts.shape[0]):
+                    candidates = []
+                    # for _ in range(n_candidates):
+                    #     if args.split_memory: # if splitting model into CPU/GPU, send src_image from CPU memory
+                    #         pixels = model.sampling_batch(texts[i].to(device), src_image.unsqueeze(0),
+                    #                                       sent_embeds.unsqueeze(0).to(device), top_k=top_k, top_p=top_p,
+                    #                                       prompt=None, n_candidates=1, device=device).cpu()
+                    #     else:
+                    #         pixels = model.sampling_batch(texts[i].to(device), src_image.unsqueeze(0).to(device),
+                    #                                       sent_embeds.unsqueeze(0).to(device), top_k=top_k, top_p=top_p,
+                    #                                       prompt=None, n_candidates=1).cpu()
+                    #     print(pixels.shape)
+                    #     candidates.append(pixels.squeeze())
+                    # stories.append(torch.stack(candidates))
+                    #with torch.cuda.amp.autocast():
+
                     pixels = model.sampling_batch(texts[i].to(device), src_image.unsqueeze(0).to(device),
-
-
+                                                  sent_embeds.unsqueeze(0).to(device), top_k=top_k, top_p=top_p,
+                                                  prompt=None, n_candidates=n_candidates).cpu()
                     stories.append(pixels)
-
             img = save_story_results(stories, video_len=4, n_candidates=n_candidates, mask=mask)
-            save_image(img,
+            save_image(img, img_file_path, normalize=True)
+
+        else:
+            img_file_path = "gradio_demo_pororo.png"
 
-        return
+        return img_file_path
 
     with gr.Blocks(css='#output {width:750px; height:750px; float:left;}') as demo:
         gr.Markdown('''
@@ -170,7 +263,7 @@
     Here are some examples of generated visual stories for the above-mentioned settings.
 
     <p align="center">
-    <img src="file/
+    <img src="file/demo_pororo_good_v1.png" width="1000">
     </p>
 
     Due to the small training dataset size for story visualization, the model has poor generalization to some unseen settings. The model struggles to generate coherent images in the following scenarios.
@@ -236,10 +329,11 @@
     \[4\] Sharma, Piyush, et al. "Conceptual captions: A cleaned, hypernymed, image alt-text dataset for automatic image captioning." Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2018.
     ''')
 
-    demo.launch(share=
+    demo.launch(share=False)
 
 
 if __name__ == "__main__":
+
     args_list = ['--model_name_or_path', './ckpt/25.pth',
                  '--prefix_model_name_or_path', './1.3B/',
                  '--dataset_name', 'pororo',
@@ -351,6 +445,7 @@ if __name__ == "__main__":
     )
 
     parser.add_argument("--debug", action="store_true", help="Whether to debug the demo.")
+    parser.add_argument("--split_memory", action="store_true", help="Whether to split the model into GPU & CPU in the demo.")
 
     args = parser.parse_args(args_list)
 
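For context on the app.py changes above: the demo now punctuates and masks the four input captions, writes them to disk, and passes them through resolve_coref so that pronouns are rewritten as the character names the model was trained on. The sketch below is illustrative only; StubCorefPredictor and preprocess_captions are made-up names standing in for the AllenNLP SpanBERT predictor and the in-function logic, and only the coref_resolved() call mirrors what app.py actually uses.

# Minimal, self-contained sketch of the caption pre-processing added in this commit.
# The stub imitates only the coref_resolved() method used from AllenNLP's predictor;
# the real demo loads the SpanBERT coref model with Predictor.from_path(...).

class StubCorefPredictor:
    def coref_resolved(self, document: str) -> str:
        # The real model rewrites pronouns with their antecedents, e.g. "He" -> "Pororo".
        return document.replace("He ", "Pororo ")

def preprocess_captions(raw_captions, predictor):
    # 1. strip and make sure each non-empty caption ends with terminal punctuation (as in app.py)
    captions = [c.strip() for c in raw_captions]
    captions = [c + "." if c and c[-1] not in "!?." else c for c in captions]
    # 2. mask out empty caption slots
    mask = [1 if c else 0 for c in captions]
    # 3. join the non-empty captions into one document and resolve coreferences across frames,
    #    which is what resolve_coref() in app.py does with per-sentence bookkeeping
    doc = " ".join(c for c in captions if c)
    resolved = predictor.coref_resolved(doc)
    return resolved, mask

print(preprocess_captions(
    ["Pororo is at the beach", "He builds a sandcastle", "", ""],
    StubCorefPredictor()))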
dalle/models/__init__.py
CHANGED
@@ -1094,7 +1094,7 @@ class PromptConditionalDalle(Dalle):
         prompt = self.get_prompt(bsz=5, eval=True)
 
         images = []
-        for t in texts:
+        for i, t in enumerate(texts):
             pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
             pixels = np.transpose(pixels, (0, 2, 3, 1))
             images.append(pixels)
@@ -1211,7 +1211,6 @@ class StoryDalle(Dalle):
                                            lowercase=True,
                                            dropout=None)
 
-
         return model, config_update
 
 
@@ -1224,6 +1223,7 @@
                                    resid_pdrop=hparams.resid_pdrop,
                                    attn_pdrop=hparams.attn_pdrop) for i in cross_attention_layers]
 
+
     def get_prompt_p5(self, bsz=None, eval=False):
         input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
         temp_control = self.wte(input_tokens)
@@ -1232,6 +1232,7 @@
         past_key_values = self.dropout(past_key_values)
         return past_key_values
 
+
     def forward(self,
                 images: torch.FloatTensor,
                 src_images: Optional[torch.FloatTensor],
@@ -1287,6 +1288,7 @@
         # print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
         return logits_img, logits_txt, codes
 
+
     @torch.no_grad()
     def sampling(self,
                  tokens: torch.LongTensor,
@@ -1327,6 +1329,7 @@
 
         #with autocast(enabled=False):
         src_codes = self.stage1.get_codes(source).detach()
+
         src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len, dim=0)
         print(tokens.shape, src_codes.shape, prompt.shape)
         if self.config.story.condition:
@@ -1355,6 +1358,7 @@
         pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
         return pixels
 
+
     @torch.no_grad()
     def sampling_batch(self,
                        tokens: torch.LongTensor,
@@ -1363,10 +1367,8 @@
                        top_k: int = 256,
                        top_p: Optional[float] = None,
                        softmax_temperature: float = 1.0,
-                       num_candidates: int = 96,
                        device: str = 'cuda:0',
                        use_fp16: bool = True,
-                       labels=None,
                        prompt=None, n_candidates=1) -> torch.FloatTensor:
 
         self.stage1.eval()
@@ -1396,37 +1398,40 @@
 
         #with autocast(enabled=False):
         src_codes = self.stage1.get_codes(source).detach()
+        # src_codes = src_codes.to(device=device) #ensure that src_codes is moved to GPU in case VQGAN was kept in CPU
 
-        # repeat inputs to adjust to n_candidates and story length
-        src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
-        prompt = prompt.repeat(n_candidates, 1, 1)
-        pos_enc_prompt = pos_enc_prompt.repeat(n_candidates, 1)
-        tokens = tokens.repeat(n_candidates, 1)
-        print(tokens.shape, src_codes.shape, prompt.shape, pos_enc_prompt.shape)
-        if self.config.story.condition:
-            codes = sampling_conditional(self.stage2,
-                                         self.cross_attention_idxs,
-                                         self.cross_attention_layers,
-                                         tokens,
-                                         src_codes,
-                                         top_k=top_k,
-                                         top_p=top_p,
-                                         softmax_temperature=softmax_temperature,
-                                         use_fp16=use_fp16,
-                                         prompt=prompt,
-                                         pos_prompt=pos_enc_prompt)
-        else:
-            codes = sampling(self.stage2,
-                             tokens,
-                             top_k=top_k,
-                             top_p=top_p,
-                             softmax_temperature=softmax_temperature,
-                             use_fp16=use_fp16,
-                             prompt=prompt,
-                             pos_prompt=pos_enc_prompt)
-
-        codes = codes.view(self.config.story.story_len * n_candidates, 16, 16) # [B, 16, 16]
-        print(codes.shape)
+        with torch.autocast(device_type='cuda', dtype=torch.float16):
+            # repeat inputs to adjust to n_candidates and story length
+            src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
+            prompt = prompt.repeat(n_candidates, 1, 1)
+            pos_enc_prompt = pos_enc_prompt.repeat(n_candidates, 1)
+            tokens = tokens.repeat(n_candidates, 1)
+            print(tokens.shape, src_codes.shape, prompt.shape, pos_enc_prompt.shape)
+            if self.config.story.condition:
+                codes = sampling_conditional(self.stage2,
+                                             self.cross_attention_idxs,
+                                             self.cross_attention_layers,
+                                             tokens,
+                                             src_codes,
+                                             top_k=top_k,
+                                             top_p=top_p,
+                                             softmax_temperature=softmax_temperature,
+                                             use_fp16=use_fp16,
+                                             prompt=prompt,
+                                             pos_prompt=pos_enc_prompt)
+            else:
+                codes = sampling(self.stage2,
+                                 tokens,
+                                 top_k=top_k,
+                                 top_p=top_p,
+                                 softmax_temperature=softmax_temperature,
+                                 use_fp16=use_fp16,
+                                 prompt=prompt,
+                                 pos_prompt=pos_enc_prompt)
+
+            codes = codes.view(self.config.story.story_len * n_candidates, 16, 16) # [B, 16, 16]
+            print(codes.shape)
 
         pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 3, 256, 256]
         print(pixels.shape)
         return pixels.view(n_candidates, self.config.story.story_len, pixels.shape[-3], pixels.shape[-2], pixels.shape[-1])
@@ -1444,11 +1449,10 @@
         pred = pred.view(bs, 16, 16) # [B, 16, 16]
         pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy() # [B, 256, 256]
         pixels = np.transpose(pixels, (0, 2, 3, 1))
-
         prompt = self.get_prompt(bsz=5, eval=True)
 
         images = []
-        for t in texts:
+        for i, t in enumerate(texts):
             pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
             pixels = np.transpose(pixels, (0, 2, 3, 1))
             images.append(pixels)
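The headline change for the "fp16 version" is in sampling_batch above: the stage-2 transformer sampling now runs inside torch.autocast with float16, while the stage-1 decode back to pixels stays outside the autocast block in full precision. A minimal sketch of that pattern, with toy modules rather than the real StoryDalle stages (it falls back to bfloat16 autocast on CPU when no GPU is available):

import torch
import torch.nn as nn

# Toy stand-ins for the stage-2 transformer and the stage-1 decoder; illustrative only.
stage2 = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
stage1_decode = nn.Linear(64, 3 * 16 * 16)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
amp_dtype = torch.float16 if device.type == 'cuda' else torch.bfloat16  # CPU autocast supports bfloat16
stage2.to(device)
stage1_decode.to(device)

tokens = torch.randn(4, 64, device=device)

with torch.no_grad():
    # sampling runs in reduced precision ...
    with torch.autocast(device_type=device.type, dtype=amp_dtype):
        codes = stage2(tokens)
    # ... decoding back to pixel space happens outside autocast, in full precision
    pixels = torch.clamp(stage1_decode(codes.float()) * 0.5 + 0.5, 0, 1)

print(pixels.dtype, pixels.shape)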
dalle/models/__pycache__/__init__.cpython-38.pyc
CHANGED
Binary files a/dalle/models/__pycache__/__init__.cpython-38.pyc and b/dalle/models/__pycache__/__init__.cpython-38.pyc differ
dalle/models/stage2/__pycache__/layers.cpython-38.pyc
CHANGED
Binary files a/dalle/models/stage2/__pycache__/layers.cpython-38.pyc and b/dalle/models/stage2/__pycache__/layers.cpython-38.pyc differ
dalle/models/stage2/layers.py
CHANGED
@@ -182,8 +182,13 @@ class Block(nn.Module):
     def sample_with_context(self, x, context, context_mask, cross_attn_layer, layer_past=None):
         attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
         x = x + attn
-
-
+
+        c_attn = cross_attn_layer(x.to(device=context.device),
+                                  context,
+                                  context_mask.to(device=context.device))
+
+        x = x + c_attn.to(device=x.device)
+
         x = x + self.mlp(self.ln2(x))
         return x, present
 
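The layers.py change makes sample_with_context tolerant of the block and the cross-attention layer living on different devices: activations hop to the context's device for the cross-attention call and hop back before the MLP. A small sketch of the same hop-and-return pattern using plain linear layers as stand-ins (toy names, not the real Block or cross-attention classes):

import torch
import torch.nn as nn

cpu = torch.device('cpu')
gpu = torch.device('cuda') if torch.cuda.is_available() else cpu  # degrades to CPU-only if no GPU

block = nn.Linear(32, 32).to(cpu)        # stand-in for the transformer block
cross_attn = nn.Linear(32, 32).to(gpu)   # stand-in for a cross-attention layer kept on the GPU
context = torch.randn(4, 32, device=gpu)

x = torch.randn(4, 32, device=cpu)
x = x + block(x)

# hop to the context's device for cross-attention, then hop back, as sample_with_context now does
c_attn = cross_attn(x.to(device=context.device) + context)
x = x + c_attn.to(device=x.device)

print(x.device)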
gradio_demo_pororo.png
ADDED
Git LFS Details
requirements.txt
CHANGED
@@ -10,3 +10,4 @@ pytorch-lightning
 einops
 tokenizers
 tensorflow
+allennlp==2.10.0
|