adymaharana commited on
Commit
3d5e231
1 Parent(s): 84f1d0c

Added files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .idea/.gitignore +8 -0
  2. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  3. .idea/misc.xml +4 -0
  4. .idea/modules.xml +8 -0
  5. .idea/storydalle.iml +12 -0
  6. .idea/vcs.xml +6 -0
  7. 1.3B/config.yaml +38 -0
  8. 1.3B/tokenizer/bpe-16k-merges.txt +0 -0
  9. 1.3B/tokenizer/bpe-16k-vocab.json +0 -0
  10. app.py +353 -0
  11. dalle/__init__.py +0 -0
  12. dalle/__pycache__/__init__.cpython-38.pyc +0 -0
  13. dalle/__pycache__/trainer_prefix.cpython-38.pyc +0 -0
  14. dalle/models/__init__.py +1462 -0
  15. dalle/models/__pycache__/__init__.cpython-38.pyc +0 -0
  16. dalle/models/__pycache__/prefix_tuning_model.cpython-38.pyc +0 -0
  17. dalle/models/__pycache__/tokenizer.cpython-38.pyc +0 -0
  18. dalle/models/stage1/__pycache__/layers.cpython-38.pyc +0 -0
  19. dalle/models/stage1/__pycache__/vqgan.cpython-38.pyc +0 -0
  20. dalle/models/stage1/layers.py +373 -0
  21. dalle/models/stage1/vqgan.py +93 -0
  22. dalle/models/stage2/__pycache__/layers.cpython-38.pyc +0 -0
  23. dalle/models/stage2/__pycache__/transformer.cpython-38.pyc +0 -0
  24. dalle/models/stage2/layers.py +216 -0
  25. dalle/models/stage2/transformer.py +502 -0
  26. dalle/models/tokenizer.py +35 -0
  27. dalle/trainer_prefix.py +1629 -0
  28. dalle/utils/__init__.py +3 -0
  29. dalle/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  30. dalle/utils/__pycache__/config.cpython-38.pyc +0 -0
  31. dalle/utils/__pycache__/sampling.cpython-38.pyc +0 -0
  32. dalle/utils/__pycache__/utils.cpython-38.pyc +0 -0
  33. dalle/utils/config.py +209 -0
  34. dalle/utils/sampling.py +369 -0
  35. dalle/utils/utils.py +131 -0
  36. demo/Barney.png +0 -0
  37. demo/Betty.png +0 -0
  38. demo/Crong.png +0 -0
  39. demo/Dino.png +0 -0
  40. demo/Eddy.png +0 -0
  41. demo/Fred.png +0 -0
  42. demo/Harry.png +0 -0
  43. demo/Loopy.png +0 -0
  44. demo/MrSlate.png +0 -0
  45. demo/Pebbles.png +0 -0
  46. demo/Petty.png +0 -0
  47. demo/Poby.png +0 -0
  48. demo/Pororo.png +0 -0
  49. demo/Rody.png +0 -0
  50. demo/Tongtong.png +0 -0
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/storydalle.iml" filepath="$PROJECT_DIR$/.idea/storydalle.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/storydalle.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="format" value="PLAIN" />
10
+ <option name="myDocStringFormat" value="Plain" />
11
+ </component>
12
+ </module>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="$PROJECT_DIR$" vcs="Git" />
5
+ </component>
6
+ </project>
1.3B/config.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ tokenizer_type: CharBPE
3
+ context_length: 64
4
+ image_resolution: 256
5
+
6
+ stage1:
7
+ type: vqgan
8
+ embed_dim: 256
9
+ n_embed: 16384
10
+ hparams:
11
+ double_z: False
12
+ z_channels: 256
13
+ resolution: 256
14
+ in_channels: 3
15
+ out_ch: 3
16
+ ch: 128
17
+ ch_mult: [1, 1, 2, 2, 4]
18
+ num_res_blocks: 2
19
+ attn_resolutions: [16]
20
+ pdrop: 0.0
21
+
22
+ stage2:
23
+ type: transformer1d
24
+ vocab_size_txt: 16384
25
+ vocab_size_img: 16384
26
+ hparams:
27
+ embed_dim: 1536
28
+ n_layers: 42
29
+ n_heads: 24
30
+ n_dense_layers: 42
31
+ ctx_len_img: 256
32
+ ctx_len_txt: 64
33
+ embd_pdrop: 0.0
34
+ resid_pdrop: 0.0
35
+ attn_pdrop: 0.0
36
+ mlp_bias: True
37
+ attn_bias: True
38
+ gelu_use_approx: False
1.3B/tokenizer/bpe-16k-merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
1.3B/tokenizer/bpe-16k-vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, torch
2
+ import gradio as gr
3
+ import torchvision.utils as vutils
4
+ import torchvision.transforms as transforms
5
+ from dalle.models import StoryDalle
6
+ import argparse
7
+ from PIL import Image
8
+ import numpy as np
9
+ from torchvision.utils import save_image
10
+ import tensorflow_hub as hub
11
+ import gdown
12
+
13
+
14
+ source_frame_paths = {
15
+ 'Pororo': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_2/Pororo_ENGLISH1_2_ep6/12.png',
16
+ 'Loopy': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_1/Pororo_ENGLISH1_1_ep12/26.png',
17
+ 'Crong': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_1/Pororo_ENGLISH1_1_ep12/10.png',
18
+ 'Poby': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_1/Pororo_ENGLISH1_1_ep9/34.png',
19
+ 'Eddy': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_1/Pororo_ENGLISH1_1_ep12/46.png',
20
+ 'Petty': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH2_1/Pororo_ENGLISH2_1_ep1/34.png',
21
+ 'Tongtong': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH3_1/Pororo_ENGLISH3_1_ep7/8.png',
22
+ 'Rody': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH3_1/Pororo_ENGLISH3_1_ep6/66.png',
23
+ 'Harry': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH3_1/Pororo_ENGLISH3_1_ep7/39.png',
24
+ }
25
+
26
+
27
+ def inverse_normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
28
+ mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
29
+ std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
30
+ if mean.ndim == 1:
31
+ mean = mean.view(-1, 1, 1)
32
+ if std.ndim == 1:
33
+ std = std.view(-1, 1, 1)
34
+ tensor.mul_(std).add_(mean)
35
+ return tensor
36
+
37
+
38
+ def save_story_results(images, video_len=4, n_candidates=1, mask=None):
39
+ # print("Generated Images shape: ", images.shape)
40
+
41
+ if mask is None:
42
+ mask = [1 for _ in range(len(video_len))]
43
+
44
+ all_images = []
45
+ for i in range(len(images)): # batch size = 1
46
+ for j in range(n_candidates):
47
+ story = []
48
+ for k, m in enumerate(mask):
49
+ if m == 1:
50
+ story.append(images[i][j][k])
51
+ all_images.append(vutils.make_grid(story, sum(mask), padding=0))
52
+ all_images = vutils.make_grid(all_images, 1, padding=20)
53
+ print(all_images)
54
+
55
+ pad_len = video_len - sum(mask)
56
+
57
+ if pad_len > 0:
58
+ pad_height = 256 * n_candidates + 20 * (n_candidates + 1)
59
+ pad_width = 256 * pad_len + 20 * (pad_len)
60
+ pad_image = torch.ones(3, pad_height, pad_width)
61
+
62
+ print(all_images.shape, pad_image.shape)
63
+ all_images = torch.cat([all_images[:, :, :-15], pad_image], dim=-1)
64
+
65
+ print(all_images.shape)
66
+ return all_images[:, 15:-15, 15:-15]
67
+
68
+
69
+ def main(args):
70
+ device = 'cuda:0'
71
+
72
+ model_url = 'https://drive.google.com/file/d/1lJ6zMZ6qTvFu6H35-VEdFlN13MMslivJ/view?usp=sharing'
73
+ png_url = 'https://drive.google.com/file/d/1C33A1IzSHDPoQ4QBsgFWbF61QWaAxRo_/view?usp=sharing'
74
+
75
+ gdown.download(model_url, quiet=True, use_cookies=False, output="./ckpt/25.pth")
76
+ gdown.download(png_url, quiet=True, use_cookies=False, output="demo_pororo_good.png")
77
+
78
+ if args.debug:
79
+ model = None
80
+ embed = None
81
+ else:
82
+ model, config = StoryDalle.from_pretrained(args)
83
+ model.tokenizer.add_tokens(['pororo', 'loopy', 'eddy', 'harry', 'poby', 'tongtong', 'crong', 'rody', 'petty'])
84
+ model.eval()
85
+ model.to(device=device)
86
+ embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5")
87
+
88
+ if model.config.story.condition:
89
+ for i in range(len(model.cross_attention_layers)):
90
+ model.cross_attention_layers[i].to(device)
91
+ print("Cross-attention layers are in cuda:", next(model.cross_attention_layers[0].parameters()).is_cuda)
92
+
93
+ valid_transform = transforms.Compose(
94
+ [transforms.Resize(config.dataset.image_resolution),
95
+ transforms.CenterCrop(config.dataset.image_resolution),
96
+ transforms.ToTensor(),
97
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
98
+ )
99
+
100
+ def predict(caption_1, caption_2, caption_3, caption_4, source='Pororo', top_k=32, top_p=0.2, n_candidates=4,
101
+ supercondition=False):
102
+
103
+ if not args.debug:
104
+ captions = [caption_1, caption_2, caption_3, caption_4]
105
+ mask = [1 if caption != '' else 0 for caption in captions]
106
+ print(captions, mask, source, n_candidates)
107
+ for i, caption in enumerate(captions):
108
+ if caption == "":
109
+ captions[i] = "Pororo is reading a book."
110
+ tokens = [model.tokenizer.encode(caption) for caption in captions]
111
+ texts = torch.stack([torch.LongTensor(token.ids) for token in tokens]).unsqueeze(0)
112
+ sent_embeds = torch.tensor(embed(captions).numpy())
113
+ # sent_embeds = torch.tensor(description_vecs[source_frame_paths[source].
114
+ # replace('/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/', '')[:-4]][0]).unsqueeze(0).repeat(4, 1)
115
+
116
+ src_image = valid_transform(Image.open('./demo/%s.png' % source).convert('RGB'))
117
+
118
+ stories = []
119
+ with torch.no_grad():
120
+ for i in range(texts.shape[0]):
121
+ pixels = model.sampling_batch(texts[i].to(device), src_image.unsqueeze(0).to(device),
122
+ sent_embeds.unsqueeze(0).to(device), top_k=top_k, top_p=top_p,
123
+ prompt=None, n_candidates=n_candidates).cpu()
124
+ stories.append(pixels)
125
+
126
+ img = save_story_results(stories, video_len=4, n_candidates=n_candidates, mask=mask)
127
+ save_image(img, "gradio_demo_pororo.png", normalize=True)
128
+
129
+ return "gradio_demo_pororo.png"
130
+
131
+ with gr.Blocks(css='#output {width:750px; height:750px; float:left;}') as demo:
132
+ gr.Markdown('''
133
+ <p style="text-align: center;font-size:40px;"><b>StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation</b><br><font size="6">Adyasha Maharana, Darryl Hannan and Mohit Bansal (UNC Chapel Hill)<br>Published at <b>ECCV 2022</b></font></p>
134
+
135
+ StoryDALL-E \[1\] is a model trained for the task of Story Visualization \[2\].
136
+ The model receives a sequence of captions as input and generates a corresponding sequence of images which form a visual story depicting the narrative in the captions.
137
+ We modify this task to enable the model to receive an initial scene as input, which can be used as a cue for the setting of the story and also for generating unseen or low-resource visual elements. We refer to this task as Story Continuation \[1\].
138
+ StoryDALL-E is based on the [mega-dalle](https://github.com/borisdayma/dalle-mini) model and is adapted from the corresponding [PyTorch codebase](https://github.com/kuprel/min-dalle).
139
+ **This model has been developed for academic purposes only.**
140
+
141
+ \[[Paper](http://arxiv.org/abs/2209.06192)\] \[[Code](https://github.com/adymaharana/storydalle)\] \[[Model Card](https://github.com/adymaharana/storydalle/blob/main/MODEL_CARD.MD)\]
142
+
143
+ ### Dataset
144
+ This model has been trained using the Pororo story visualization dataset \[1\].
145
+ The data was adapted from the popular cartoon series *Pororo the Little Penguin* and originally released by \[2\].
146
+ The Pororo dataset contains 9 recurring characters, as shown below, in the decreasing order of their frequency in the training data.
147
+ <p align="center">
148
+ <img src="file/pororo_characters.png" width="800">
149
+ </p>
150
+ The training dataset contains nearly 10,000 samples in the training set. Most of the scenes occur in a snowy village, surrounded by hills, trees and houses. A few episodes are located in gardens or water bodies. All the captions are in the English language and predominantly contain verbs in the present tense. Additionally, the training of this model starts from the pretrained checkpoint of mega-dalle, which is trained on the Conceptual Captions dataset.
151
+
152
+ ### Intended Use
153
+ This model is intended for generating visual stories containing the 9 characters in the Pororo dataset. This version of the StoryDALL-E model is reasonable at the following scenarios:
154
+ * Frames containing a single character.
155
+ * Overtly visual actions such as *making cookies*, *walking*, *reading a book*, *sitting*.
156
+ * Scenes taking place in snowy settings, indoors and gardens.
157
+ * Visual stories contaning 1-3 characters across all frames.
158
+ * Scene transitions e.g. from day to night.
159
+ * Moderately capable of generating semantic concepts that do not appear in the story continuation dataset, such as *doughnut* and *lion*.
160
+
161
+ Here are some examples of generated visual stories for the above-mentioned settings.
162
+
163
+ <p align="center">
164
+ <img src="file/demo_pororo_good.png" width="1000">
165
+ </p>
166
+
167
+ Due to the small training dataset size for story visualization, the model has poor generalization to some unseen settings. The model struggles to generate coherent images in the following scenarios.
168
+ * Multiple characters in a frame.
169
+ * Non-visual actions such as *compliment*.
170
+ * Characters that are infrequent in the training dataset e.g. Rody, Harry.
171
+ * Background locations that are not found in the cartoon e.g. a busy city.
172
+ * Color-based descriptions for object.
173
+ * Completely new characters based on textual descriptions.
174
+
175
+ In the following demo, four or less captions can be entered in the `caption` text fields for the visual story.
176
+ Select a `source` frame based on the character that is predominant in your visual story.
177
+ `top_k` refers to the number of highest probability vocabulary tokens to keep for top-k-filtering.
178
+ Only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
179
+ Set `supercondition` to True to enable generation using a null hypothesis.
180
+ Select between 1-4 `n_candidates` to generate a diverse set of stories for the given captions.
181
+ <br><br>
182
+ Feel free to send feedback to [email protected].
183
+ ''')
184
+
185
+ with gr.Row():
186
+ with gr.Column():
187
+ caption_1 = gr.Textbox(label="Caption 1", value='Pororo is reading a book.')
188
+ caption_2 = gr.Textbox(label="Caption 2", value='Pororo is sleeping on the couch.')
189
+ caption_3 = gr.Textbox(label="Caption 3", value='Pororo wakes up in the middle of the night in his bed.')
190
+ caption_4 = gr.Textbox(label="Caption 4", value='Pororo is in his bedroom and looks terrified.')
191
+ source = gr.Radio(["Pororo", "Loopy", "Crong", "Poby", "Eddy", "Petty", "Tongtong", "Rody", "Harry"],
192
+ label="Source", value="Pororo")
193
+ top_k = gr.Slider(16, 128, label="top_k", value=32)
194
+ top_p = gr.Slider(0.01, 1.0, label="top_p", value=0.2)
195
+ supercondition = gr.Checkbox(value=False, label='supercondition')
196
+ n_candidates = gr.Dropdown([1, 2, 3, 4], value=4, label='n_candidates')
197
+
198
+ with gr.Row():
199
+ # clear_btn = gr.Button("Clear")
200
+ submit_btn = gr.Button("Submit")
201
+
202
+ with gr.Column():
203
+ with gr.Row():
204
+ frame_1_label = gr.Button("Frame 1")
205
+ frame_2_label = gr.Button("Frame 2")
206
+ frame_3_label = gr.Button("Frame 3")
207
+ frame_4_label = gr.Button("Frame 4")
208
+ # frame_1_label = gr.Label("Frame 1")
209
+ # frame_2_label = gr.Label("Frame 2")
210
+ # frame_3_label = gr.Label("Frame 3")
211
+ # frame_4_label = gr.Label("Frame 4")
212
+ output = gr.Image(label="", elem_id='output')
213
+
214
+ submit_btn.click(fn=predict,
215
+ inputs=[caption_1, caption_2, caption_3, caption_4, source, top_k, top_p, n_candidates,
216
+ supercondition], outputs=output)
217
+
218
+ gr.Markdown('''
219
+ ### References
220
+
221
+ \[1\] Maharana, Adyasha, et al. "StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation." ECCV. 2022.
222
+
223
+ \[2\] Li, Yitong, et al. "Storygan: A sequential conditional gan for story visualization." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
224
+
225
+ \[3\] Kim, Kyung-Min, et al. "DeepStory: video story QA by deep embedded memory networks." Proceedings of the 26th International Joint Conference on Artificial Intelligence. 2017.
226
+
227
+ \[4\] Sharma, Piyush, et al. "Conceptual captions: A cleaned, hypernymed, image alt-text dataset for automatic image captioning." Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2018.
228
+ ''')
229
+
230
+ demo.launch(share=True)
231
+
232
+
233
+ if __name__ == "__main__":
234
+ args_list = ['--model_name_or_path', './ckpt/25.pth',
235
+ '--prefix_model_name_or_path', './1.3B/',
236
+ '--dataset_name', 'pororo',
237
+ '--tuning_mode', 'story',
238
+ '--preseqlen', '32',
239
+ '--condition',
240
+ '--story_len', '4',
241
+ '--sent_embed', '512',
242
+ '--prefix_dropout', '0.2',
243
+ '--data_dir', '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/',
244
+ '--dataloader_num_workers', '1',
245
+ '--do_eval',
246
+ '--per_gpu_eval_batch_size', '16',
247
+ '--mode', 'story']
248
+
249
+ parser = argparse.ArgumentParser(description='arguments for training/evaluating prefix-tuning DALLE')
250
+
251
+ # Model Arguments
252
+ parser.add_argument('--model_name_or_path', type=str, default=None,
253
+ help='The model checkpoint for weights initialization.')
254
+ parser.add_argument('--prefix_model_name_or_path', type=str, default=None,
255
+ help='The prefix model checkpoint for weights initialization.')
256
+ parser.add_argument('--prefix_mode', type=str, default='activation', help='activation or embedding')
257
+ parser.add_argument('--preseqlen', type=int, default=0, help='how many tokens of prefix should we include.')
258
+ parser.add_argument('--optim_prefix', action="store_true",
259
+ help='set to True if optimizing prefix directly; no if through amortized function')
260
+ parser.add_argument('--tuning_mode', type=str, default='prefixtune', help='prefixtune or finetune')
261
+ parser.add_argument('--top_k_layers', type=int, default=2,
262
+ help='In finetuning setting, if we only tune the top k layers.')
263
+ parser.add_argument('--parameterize_mode', type=str, default='mlp',
264
+ help="mlp or emb to parametrize when we optimize for the embeddings.")
265
+ parser.add_argument('--prefix_dropout', type=float, default=0.0, help='dropout rate for the prefix tuning model.')
266
+ parser.add_argument('--teacher_dropout', type=float, default=0.0, help='dropout rate for the teacher model.')
267
+ parser.add_argument('--init_random', action="store_true", help="set True if initializing random embeddings")
268
+ parser.add_argument('--init_shallow', action="store_true", help="set True if not using reparameterization")
269
+ parser.add_argument('--init_shallow_word', type=bool, default=False,
270
+ help="set True if init_shallow and specify words")
271
+ parser.add_argument('--replay_buffer', action="store_true", help="set True if using replay buffer in training")
272
+ parser.add_argument('--gumbel', action="store_true", help="set True if using the gumbel softmax in training")
273
+ parser.add_argument('--hidden_dim_prefix', type=float, default=512, help="hidden dim of MLP for generating prefix?")
274
+
275
+ # Data Arguments
276
+ parser.add_argument('--dataset_name', type=str, default='pororo', help="dataset name")
277
+ parser.add_argument('--data_dir', type=str, default=None, help="Path to data directory")
278
+ parser.add_argument('--lowdata_token', type=str, default='story',
279
+ help="The token to be prepended at initialization time.")
280
+ parser.add_argument('--use_lowdata_token', type=bool, default=True,
281
+ help="Whether we should use the lowdata token for prefix-tuning")
282
+ parser.add_argument('--train_embeddings', action="store_true", help="Whether to train word embeddings")
283
+ parser.add_argument('--train_max_target_length', type=int, default=100,
284
+ help='the max target length for training data.')
285
+ parser.add_argument('--val_max_target_length', type=int, default=100, help='the max target length for dev data.')
286
+ parser.add_argument('--dataloader_num_workers', type=int, default=8, help='number of workers when loading data')
287
+
288
+ # new arguments for story
289
+ parser.add_argument('--prompt', action="store_true", help="set True if using prompts in StoryDALLE")
290
+ parser.add_argument('--story_len', type=int, default=4, help='the max target length for dev data.')
291
+ parser.add_argument('--sent_embed', type=int, default=384, help='the max target length for dev data.')
292
+ parser.add_argument('--condition', action="store_true", help="set True if using prompts in StoryDALLE")
293
+ parser.add_argument('--clip_embed', action="store_true", help="set True if using prompts in StoryDALLE")
294
+
295
+ # Training Arguments
296
+ parser.add_argument('--output_dir', type=str, default=None, help="Path to data directory")
297
+ parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
298
+ parser.add_argument("--do_eval", action="store_true", help="Whether to run evaluation.")
299
+ parser.add_argument("--do_test", action="store_true", help="Whether to run test.")
300
+ parser.add_argument('--seed', type=int, default=42, help='seed for reproducibility')
301
+ parser.add_argument("--overwrite_output_dir", action="store_true", help="Whether to overwrite output dir.")
302
+ parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
303
+ parser.add_argument(
304
+ "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
305
+ )
306
+ parser.add_argument(
307
+ "--gradient_accumulation_steps",
308
+ type=int,
309
+ default=1,
310
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
311
+ )
312
+
313
+ parser.add_argument('--mode', type=str, default='val', help="mval or test.")
314
+
315
+ parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
316
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
317
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
318
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
319
+ parser.add_argument(
320
+ "--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
321
+ )
322
+ parser.add_argument(
323
+ "--max_steps",
324
+ default=-1,
325
+ type=int,
326
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
327
+ )
328
+ parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
329
+ parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
330
+ parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
331
+ parser.add_argument(
332
+ "--eval_all_checkpoints",
333
+ action="store_true",
334
+ help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
335
+ )
336
+ parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
337
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
338
+ parser.add_argument(
339
+ "--fp16",
340
+ action="store_true",
341
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
342
+ )
343
+
344
+ parser.add_argument("--debug", action="store_true", help="Whether to debug the demo.")
345
+
346
+ args = parser.parse_args(args_list)
347
+
348
+ main(args)
349
+
350
+
351
+
352
+
353
+
dalle/__init__.py ADDED
File without changes
dalle/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (148 Bytes). View file
 
dalle/__pycache__/trainer_prefix.cpython-38.pyc ADDED
Binary file (52.7 kB). View file
 
dalle/models/__init__.py ADDED
@@ -0,0 +1,1462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ import torch
9
+ import torch.nn as nn
10
+ import pytorch_lightning as pl
11
+ from typing import Optional, Tuple, Union
12
+ from omegaconf import OmegaConf
13
+ from torch.cuda.amp import autocast
14
+ from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
15
+ from torch.nn import functional as F
16
+ from .stage1.vqgan import VQGAN
17
+ from .stage2.transformer import Transformer1d, iGPT
18
+ from .stage2.layers import Block
19
+ from .. import utils
20
+ from ..utils.config import get_base_config
21
+ from ..utils.sampling import sampling, sampling_igpt, get_positional_encoding, sampling_prefix, sampling_conditional
22
+ from ..utils.utils import save_image
23
+ from .tokenizer import build_tokenizer
24
+ import numpy as np
25
+ from .stage2.layers import CrossAttentionLayer
26
+
27
+ _MODELS = {
28
+ 'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
29
+ }
30
+
31
+ class Dalle(pl.LightningModule):
32
+ def __init__(self,
33
+ config: OmegaConf) -> None:
34
+ super().__init__()
35
+ self.tokenizer = None
36
+ self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
37
+ embed_dim=config.stage1.embed_dim,
38
+ hparams=config.stage1.hparams)
39
+ self.stage2 = Transformer1d(vocab_size_txt=config.stage2.vocab_size_txt,
40
+ vocab_size_img=config.stage2.vocab_size_img,
41
+ hparams=config.stage2.hparams)
42
+ self.config = config
43
+ self.config_stage1 = config.stage1
44
+ self.config_stage2 = config.stage2
45
+ self.config_dataset = config.dataset
46
+
47
+ # # make the parameters in stage 1 not trainable
48
+ # self.stage1.eval()
49
+ # for p in self.stage1.parameters():
50
+ # p.requires_grad = False
51
+
52
+ @classmethod
53
+ def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
54
+
55
+ path = args.model_name_or_path
56
+ config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
57
+ if args.do_train:
58
+ config_base = get_base_config('finetuning')
59
+ config_update = OmegaConf.merge(config_base, config_new)
60
+ for key, val in vars(args).items():
61
+ if key in config_update.optimizer.keys():
62
+ OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
63
+ if key in config_update.experiment.keys():
64
+ OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
65
+ else:
66
+ config_base = get_base_config('default')
67
+ config_update = OmegaConf.merge(config_base, config_new)
68
+
69
+ model = cls(config_update)
70
+ model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
71
+ context_length=model.config_dataset.context_length,
72
+ lowercase=True,
73
+ dropout=None)
74
+
75
+ print("Loading models from checkpoint %s" % path)
76
+
77
+ if hasattr(args, 'dalle_path') and args.dalle_path and args.dalle_path.endswith('.pth'):
78
+ model.load_state_dict(torch.load(args.dalle_path)["model_state_dict"])
79
+ else:
80
+ model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
81
+ model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
82
+
83
+ return model, config_update
84
+
85
+
86
+ @torch.no_grad()
87
+ def sampling(self,
88
+ prompt: Union[str, torch.LongTensor],
89
+ top_k: int = 256,
90
+ top_p: Optional[float] = None,
91
+ softmax_temperature: float = 1.0,
92
+ num_candidates: int = 96,
93
+ device: str = 'cuda:0',
94
+ use_fp16: bool = True) -> torch.FloatTensor:
95
+ self.stage1.eval()
96
+ self.stage2.eval()
97
+
98
+ if type(prompt) == str:
99
+ tokens = self.tokenizer.encode(prompt)
100
+ tokens = torch.LongTensor(tokens.ids)
101
+ else:
102
+ tokens = prompt
103
+ tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
104
+
105
+ # Check if the encoding works as intended
106
+ # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
107
+
108
+ tokens = tokens.to(device)
109
+ codes = sampling(self.stage2,
110
+ tokens,
111
+ top_k=top_k,
112
+ top_p=top_p,
113
+ softmax_temperature=softmax_temperature,
114
+ use_fp16=use_fp16)
115
+ codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
116
+ pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
117
+ return pixels
118
+
119
+ def forward(self,
120
+ images: torch.FloatTensor,
121
+ texts: Optional[torch.LongTensor],
122
+ past=None
123
+ ) -> tuple:
124
+ B, C, H, W = images.shape
125
+ with torch.no_grad():
126
+ with autocast(enabled=False):
127
+ codes = self.stage1.get_codes(images).detach()
128
+ pos_enc_tokens = get_positional_encoding(texts, mode='1d')
129
+ codes = codes.clone().detach()
130
+ pos_enc_code = get_positional_encoding(codes, mode='1d')
131
+ # codes = codes.unsqueeze(-1)
132
+ # pos_enc_code = pos_enc_code.unsqueeze(-1)
133
+ logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens, past)
134
+ return logits_img, logits_txt, codes
135
+
136
+ def training_step(self, batch, batch_idx):
137
+ images, texts = batch
138
+ logits_img, logits_txt, codes = self(images, texts)
139
+
140
+ loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
141
+ loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
142
+ self.log("train/loss_img", loss_img, on_step=True, on_epoch=True, prog_bar=False, logger=True)
143
+ self.log("train/loss_txt", loss_txt, on_step=True, on_epoch=True, prog_bar=False, logger=True)
144
+ return loss_img + loss_txt
145
+
146
+ def validation_step(self, batch, batch_idx):
147
+ images, texts = batch
148
+ logits_img, logits_txt, codes = self(images, texts)
149
+ # print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
150
+
151
+ loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
152
+ loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
153
+ self.log("val/loss_img", loss_img, on_step=False, on_epoch=True, prog_bar=False, logger=True)
154
+ self.log("val/loss_txt", loss_txt, on_step=False, on_epoch=True, prog_bar=False, logger=True)
155
+ return loss_img + loss_txt
156
+
157
+ def configure_optimizers(self):
158
+ assert self.config.optimizer.opt_type == 'adamW'
159
+ # assert self.config.optimizer.sched_type == 'cosine'
160
+
161
+ opt = torch.optim.AdamW(self.parameters(),
162
+ lr=self.config.optimizer.learning_rate,
163
+ betas=self.config.optimizer.betas,
164
+ weight_decay=self.config.optimizer.weight_decay)
165
+ # sched = CosineAnnealingLR(opt,
166
+ # T_max=self.config.optimizer.max_steps,
167
+ # eta_min=self.config.optimizer.min_lr)
168
+
169
+ def lr_lambda(current_step: int):
170
+ return max(
171
+ 0.0, float(self.config.optimizer.max_steps - current_step) / float(max(1, self.config.optimizer.max_steps))
172
+ )
173
+
174
+ sched = LambdaLR(opt, lr_lambda)
175
+ sched = {
176
+ 'scheduler': sched,
177
+ 'name': 'linear'
178
+ }
179
+ return [opt], [sched]
180
+
181
+ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
182
+ on_tpu=False, using_native_amp=False, using_lbfgs=False):
183
+ optimizer.step(closure=optimizer_closure)
184
+ self.lr_schedulers().step()
185
+ self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)
186
+
187
+ def on_epoch_start(self):
188
+ self.stage1.eval()
189
+
190
+
191
+ class ImageGPT(pl.LightningModule):
192
+ def __init__(self,
193
+ config: OmegaConf) -> None:
194
+ super().__init__()
195
+ self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
196
+ embed_dim=config.stage1.embed_dim,
197
+ hparams=config.stage1.hparams)
198
+ self.stage2 = iGPT(vocab_size_img=config.stage2.vocab_size_img,
199
+ use_cls_cond=config.stage2.use_cls_cond,
200
+ hparams=config.stage2.hparams)
201
+ self.config = config
202
+ self.use_cls_cond = config.stage2.use_cls_cond
203
+
204
+ # make the parameters in stage 1 not trainable
205
+ self.stage1.eval()
206
+ for p in self.stage1.parameters():
207
+ p.requires_grad = False
208
+
209
+ @classmethod
210
+ def from_pretrained(cls,
211
+ path_upstream: str,
212
+ path_downstream: str) -> Tuple[nn.Module, OmegaConf]:
213
+ config_base = get_base_config(use_default=False)
214
+ config_down = OmegaConf.load(path_downstream)
215
+ config_down = OmegaConf.merge(config_base, config_down)
216
+
217
+ model = cls(config_down)
218
+ model.stage1.from_ckpt(os.path.join(path_upstream, 'stage1_last.ckpt'), strict=True)
219
+ model.stage2.from_ckpt(os.path.join(path_upstream, 'stage2_last.ckpt'), strict=False)
220
+ return model, config_down
221
+
222
+ def sample(self,
223
+ cls_idx: Optional[int] = None,
224
+ top_k: int = 256,
225
+ top_p: Optional[float] = None,
226
+ softmax_temperature: float = 1.0,
227
+ num_candidates: int = 16,
228
+ device: str = 'cuda:0',
229
+ use_fp16: bool = True,
230
+ is_tqdm: bool = True) -> torch.FloatTensor:
231
+ self.stage1.eval()
232
+ self.stage2.eval()
233
+
234
+ if cls_idx is None:
235
+ sos = self.stage2.sos.repeat(num_candidates, 1, 1)
236
+ else:
237
+ sos = torch.LongTensor([cls_idx]).to(device=device)
238
+ sos = sos.repeat(num_candidates)
239
+ sos = self.stage2.sos(sos).unsqueeze(1)
240
+
241
+ codes = sampling_igpt(self.stage2,
242
+ sos=sos,
243
+ top_k=top_k,
244
+ top_p=top_p,
245
+ softmax_temperature=softmax_temperature,
246
+ use_fp16=use_fp16,
247
+ is_tqdm=is_tqdm)
248
+ codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
249
+ pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
250
+ return pixels
251
+
252
+ def forward(self,
253
+ images: torch.FloatTensor,
254
+ labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
255
+ B, C, H, W = images.shape
256
+ with torch.no_grad():
257
+ with autocast(enabled=False):
258
+ codes = self.stage1.get_codes(images).detach()
259
+ logits = self.stage2(codes, labels)
260
+ return logits, codes
261
+
262
+ def training_step(self, batch, batch_idx):
263
+ images, labels = batch
264
+ logits, codes = self(images, labels=labels if self.use_cls_cond else None)
265
+ loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
266
+ self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
267
+ return loss
268
+
269
+ def validation_step(self, batch, batch_idx):
270
+ images, labels = batch
271
+ logits, codes = self(images, labels=labels if self.use_cls_cond else None)
272
+ loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
273
+ self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
274
+ return loss
275
+
276
+ def configure_optimizers(self):
277
+ assert self.config.optimizer.opt_type == 'adamW'
278
+ assert self.config.optimizer.sched_type == 'cosine'
279
+
280
+ opt = torch.optim.AdamW(self.parameters(),
281
+ lr=self.config.optimizer.base_lr,
282
+ betas=self.config.optimizer.betas,
283
+ weight_decay=self.config.optimizer.weight_decay)
284
+ sched = CosineAnnealingLR(opt,
285
+ T_max=self.config.optimizer.max_steps,
286
+ eta_min=self.config.optimizer.min_lr)
287
+ sched = {
288
+ 'scheduler': sched,
289
+ 'name': 'cosine'
290
+ }
291
+ return [opt], [sched]
292
+
293
+ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
294
+ on_tpu=False, using_native_amp=False, using_lbfgs=False):
295
+ optimizer.step(closure=optimizer_closure)
296
+ self.lr_schedulers().step()
297
+ self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)
298
+
299
+ def on_epoch_start(self):
300
+ self.stage1.eval()
301
+
302
+
303
+ class PromptDalle(Dalle):
304
+ """Classification Head for transformer encoders"""
305
+ def __init__(self, config):
306
+ super().__init__(config)
307
+ print('Initializing the PromptTuning model')
308
+
309
+ self.config = config
310
+ self.n_embd = config.stage2.hparams.embed_dim
311
+ self.preseqlen = config.prompt.preseqlen
312
+ self.prefix_dropout = config.prompt.prefix_dropout
313
+
314
+ # DIFFERENT PARAMETRIZATION:
315
+
316
+ print('[Full prompt-tuning Setting :) ]')
317
+ self.input_tokens = torch.arange(self.preseqlen).long()
318
+ self.wte = nn.Embedding(self.preseqlen, self.n_embd)
319
+ self.control_trans = nn.Sequential(
320
+ nn.Linear(self.n_embd, self.n_embd),
321
+ nn.Tanh(),
322
+ nn.Linear(self.n_embd, self.n_embd))
323
+ self.get_prompt = self.get_prompt_p5
324
+ self.dropout = nn.Dropout(self.prefix_dropout)
325
+
326
+ ###### NUM PARAMS #########
327
+ total_param = 0
328
+ for name, param in self.named_parameters():
329
+ # print(param.shape)
330
+ total_param += param.numel()
331
+ print('Total parameters is {}'.format(total_param))
332
+
333
+
334
+ @classmethod
335
+ def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
336
+
337
+ # if not args.model_name_or_path:
338
+ # args.model_name_or_path = args.prefix_model_name_or_path
339
+
340
+ path = args.prefix_model_name_or_path
341
+ path = _MODELS[path] if path in _MODELS else path
342
+ path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))
343
+
344
+ config_base = get_base_config('prompt_tuning')
345
+ config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
346
+ config_update = OmegaConf.merge(config_base, config_new)
347
+
348
+ for key, val in vars(args).items():
349
+ if key in config_update.prompt.keys():
350
+ OmegaConf.update(config_update, "prompt.%s" % key, val, merge=False)
351
+ if key in config_update.optimizer.keys():
352
+ OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
353
+ if key in config_update.experiment.keys():
354
+ OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
355
+
356
+ model = cls(config_update)
357
+ model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
358
+ context_length=model.config_dataset.context_length,
359
+ lowercase=True,
360
+ dropout=None)
361
+
362
+ if args.model_name_or_path:
363
+ print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
364
+ # model.from_ckpt(args.model_name_or_path)
365
+ try:
366
+ model.load_state_dict(torch.load(args.model_name_or_path)['state_dict'])
367
+ except KeyError:
368
+ model.load_state_dict(torch.load(args.model_name_or_path)['model_state_dict'])
369
+
370
+ else:
371
+ print("Loading models from checkpoint %s" % path)
372
+ model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
373
+ model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
374
+
375
+ return model, config_update
376
+
377
+ def get_prompt_p5(self, bsz=None, eval=False):
378
+ input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
379
+ temp_control = self.wte(input_tokens)
380
+ past_key_values = self.control_trans(temp_control) #bsz, seqlen, layer*emb
381
+ if not eval:
382
+ past_key_values = self.dropout(past_key_values)
383
+ return past_key_values
384
+
385
+ def forward(self,
386
+ images: torch.FloatTensor,
387
+ texts: Optional[torch.LongTensor],
388
+ **kwargs,
389
+ ):
390
+
391
+ #{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
392
+
393
+ B, C, H, W = images.shape
394
+ prompt = self.get_prompt(bsz=B)
395
+ pos_enc_prompt = get_positional_encoding(self.input_tokens.unsqueeze(0).expand(B, -1).to(self.device), mode='1d')
396
+
397
+ # if self.mode_para == 2 and src_attn is not None and tgt_attn is not None:
398
+ # attention_mask = torch.cat([src_attn, tgt_attn], dim=1)
399
+
400
+
401
+ with torch.no_grad():
402
+ with autocast(enabled=False):
403
+ codes = self.stage1.get_codes(images).detach()
404
+
405
+ pos_enc_tokens = get_positional_encoding(texts, mode='1d')
406
+ codes = codes.clone().detach()
407
+ pos_enc_code = get_positional_encoding(codes, mode='1d')
408
+ # codes = codes.unsqueeze(-1)
409
+ # pos_enc_code = pos_enc_code.unsqueeze(-1)
410
+ # print(images.shape, codes.shape, texts.shape)
411
+ logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens, prompt=prompt, pos_prompt=pos_enc_prompt)
412
+ return logits_img, logits_txt, codes
413
+
414
+
415
+ @torch.no_grad()
416
+ def sampling(self,
417
+ tokens: torch.LongTensor,
418
+ prompt: torch.FloatTensor,
419
+ top_k: int = 256,
420
+ top_p: Optional[float] = None,
421
+ softmax_temperature: float = 1.0,
422
+ num_candidates: int = 96,
423
+ device: str = 'cuda:0',
424
+ use_fp16: bool = True,
425
+ labels = None) -> torch.FloatTensor:
426
+ self.stage1.eval()
427
+ self.stage2.eval()
428
+
429
+ # tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
430
+
431
+ tokens = tokens.to(device)
432
+ pos_enc_prompt = get_positional_encoding(self.input_tokens.unsqueeze(0).expand(num_candidates, -1).to(self.device), mode='1d')
433
+
434
+ codes = sampling(self.stage2,
435
+ tokens,
436
+ top_k=top_k,
437
+ top_p=top_p,
438
+ softmax_temperature=softmax_temperature,
439
+ use_fp16=use_fp16,
440
+ prompt=prompt,
441
+ pos_prompt=pos_enc_prompt)
442
+
443
+ codes = codes.view(-1, 16, 16) # [B, 16, 16]
444
+ pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
445
+ return pixels
446
+
447
+
448
+ @torch.no_grad()
449
+ def predict_step(self, batch, batch_idx, return_images=False):
450
+ orig_images, texts = batch
451
+
452
+ # extra for checks
453
+ logits_img, logits_txt, codes = self(orig_images, texts)
454
+ pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
455
+ bs = orig_images.shape[0]
456
+ pred = pred.view(bs, 16, 16) # [B, 16, 16]
457
+ pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy() # [B, 256, 256]
458
+ pixels = np.transpose(pixels, (0, 2, 3, 1))
459
+
460
+ # print(texts.shape, orig_images.shape)
461
+ prompt = self.get_prompt(bsz=5, eval=True)
462
+
463
+ images = []
464
+ for i, t in enumerate(texts):
465
+ pixels = self.sampling(t, prompt, top_k=16, num_candidates=5, labels=codes[i]).cpu().numpy()
466
+ pixels = np.transpose(pixels, (0, 2, 3, 1))
467
+ images.append(pixels)
468
+
469
+ if return_images:
470
+ return images
471
+ else:
472
+ save_image(orig_images, pixels, './out/images/pororo_prompt', batch_idx+10)
473
+ save_image(orig_images, images, './out/images/pororo_prompt', batch_idx)
474
+
475
+
476
+ class PrefixTuningDalle(Dalle):
477
+ """Classification Head for transformer encoders"""
478
+ def __init__(self, config):
479
+ super().__init__(config)
480
+ print('Initializing the PrefixTuning model')
481
+
482
+ self.config = config
483
+
484
+ self.match_n_layer = config.stage2.hparams.n_layers
485
+ self.match_n_head = config.stage2.hparams.n_heads
486
+ self.match_n_embd = config.stage2.hparams.embed_dim // config.stage2.hparams.n_heads
487
+ self.n_embd = config.stage2.hparams.embed_dim
488
+
489
+ self.optim_prefix = config.prefix.optim_prefix
490
+ self.preseqlen = config.prefix.preseqlen
491
+ self.prefix_dropout = config.prefix.prefix_dropout
492
+ self.init_random = config.prefix.init_random
493
+ self.hidden_dim_prefix = config.prefix.hidden_dim_prefix
494
+
495
+ self.lowdata_token = config.prefix.lowdata_token
496
+ self.init_shallow = config.prefix.init_shallow
497
+ self.init_shallow_word = config.prefix.init_shallow_word
498
+ self.mode_para = 0
499
+
500
+ print('PrefixTuning')
501
+ print('preseqlen is {}, optimizing the prefix directly'.format(self.preseqlen))
502
+
503
+ # DIFFERENT PARAMETRIZATION:
504
+
505
+ print('[Full prefix-tuning Setting :) ]')
506
+ self.input_tokens = torch.arange(self.preseqlen).long()
507
+ self.wte = nn.Embedding(self.preseqlen, self.n_embd)
508
+ self.control_trans = nn.Sequential(
509
+ nn.Linear(self.n_embd, self.hidden_dim_prefix),
510
+ nn.Tanh(),
511
+ nn.Linear(self.hidden_dim_prefix, self.match_n_layer * 2 * self.n_embd))
512
+ self.get_prompt = self.get_prompt_p5
513
+ self.dropout = nn.Dropout(self.prefix_dropout)
514
+
515
+ ###### NUM PARAMS #########
516
+ total_param = 0
517
+ for name, param in self.named_parameters():
518
+ # print(param.shape)
519
+ total_param += param.numel()
520
+ print('Total parameters is {}'.format(total_param))
521
+
522
+
523
+ @classmethod
524
+ def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
525
+
526
+ # if not args.model_name_or_path:
527
+ # args.model_name_or_path = args.prefix_model_name_or_path
528
+
529
+ path = args.prefix_model_name_or_path
530
+ path = _MODELS[path] if path in _MODELS else path
531
+ path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))
532
+
533
+ config_base = get_base_config('prefixtuning')
534
+ config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
535
+ config_update = OmegaConf.merge(config_base, config_new)
536
+
537
+ for key, val in vars(args).items():
538
+ if key in config_update.prefix.keys():
539
+ OmegaConf.update(config_update, "prefix.%s" % key, val, merge=False)
540
+ if key in config_update.optimizer.keys():
541
+ OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
542
+ if key in config_update.experiment.keys():
543
+ OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
544
+
545
+ model = cls(config_update)
546
+ model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
547
+ context_length=model.config_dataset.context_length,
548
+ lowercase=True,
549
+ dropout=None)
550
+
551
+ if args.model_name_or_path:
552
+ print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
553
+ # model.from_ckpt(args.model_name_or_path)
554
+ try:
555
+ model.load_state_dict(torch.load(args.model_name_or_path)['state_dict'])
556
+ except KeyError:
557
+ model.load_state_dict(torch.load(args.model_name_or_path)['model_state_dict'])
558
+
559
+ else:
560
+ print("Loading models from checkpoint %s" % path)
561
+ model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
562
+ model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
563
+
564
+ return model, config_update
565
+
566
+ def get_prompt_p5(self, bsz=None, eval=False):
567
+ input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
568
+ temp_control = self.wte(input_tokens)
569
+ past_key_values = self.control_trans(temp_control) #bsz, seqlen, layer*emb
570
+ bsz, seqlen, _ = past_key_values.shape
571
+ past_key_values = past_key_values.view(bsz, seqlen, self.match_n_layer * 2, self.match_n_head,
572
+ self.match_n_embd)
573
+ if not eval:
574
+ past_key_values = self.dropout(past_key_values)
575
+ # past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
576
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4])
577
+ # print(past_key_values.shape)
578
+ return past_key_values.split(2)
579
+
580
+ def forward(self,
581
+ images: torch.FloatTensor,
582
+ texts: Optional[torch.LongTensor],
583
+ **kwargs,
584
+ ):
585
+
586
+ #{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
587
+
588
+ B, C, H, W = images.shape
589
+
590
+ if self.mode_para == 2:
591
+ past_key_values_prompt = self.get_prompt(bsz=B)
592
+ else:
593
+ past_key_values_prompt = self.get_prompt(bsz=B)
594
+
595
+ # if self.mode_para == 2 and src_attn is not None and tgt_attn is not None:
596
+ # attention_mask = torch.cat([src_attn, tgt_attn], dim=1)
597
+
598
+
599
+ with torch.no_grad():
600
+ with autocast(enabled=False):
601
+ codes = self.stage1.get_codes(images).detach()
602
+
603
+ pos_enc_tokens = get_positional_encoding(texts, mode='1d')
604
+ codes = codes.clone().detach()
605
+ pos_enc_code = get_positional_encoding(codes, mode='1d')
606
+ # codes = codes.unsqueeze(-1)
607
+ # pos_enc_code = pos_enc_code.unsqueeze(-1)
608
+ # print(images.shape, codes.shape, texts.shape)
609
+ logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens, past_key_values_prompt)
610
+ return logits_img, logits_txt, codes
611
+
612
+ @torch.no_grad()
613
+ def sampling(self,
614
+ tokens: torch.LongTensor,
615
+ past: torch.FloatTensor,
616
+ top_k: int = 256,
617
+ top_p: Optional[float] = None,
618
+ softmax_temperature: float = 1.0,
619
+ num_candidates: int = 96,
620
+ device: str = 'cuda:0',
621
+ use_fp16: bool = True,
622
+ labels = None) -> torch.FloatTensor:
623
+ self.stage1.eval()
624
+ self.stage2.eval()
625
+
626
+ if len(past.shape) == 6:
627
+ n_layers, temp, bs, n_heads, seq_len, n_dim = past.shape
628
+ past = past.view(n_layers, temp, bs*n_heads, seq_len, n_dim)
629
+
630
+ tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
631
+
632
+ # Check if the encoding works as intended
633
+ # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
634
+
635
+ tokens = tokens.to(device)
636
+ codes = sampling_prefix(self.stage2,
637
+ tokens,
638
+ past,
639
+ top_k=top_k,
640
+ top_p=top_p,
641
+ softmax_temperature=softmax_temperature,
642
+ use_fp16=use_fp16,
643
+ labels = None if labels is None else labels.view(-1))
644
+
645
+ # codes = sampling(self.stage2,
646
+ # tokens,
647
+ # top_k=top_k,
648
+ # top_p=top_p,
649
+ # softmax_temperature=softmax_temperature,
650
+ # use_fp16=use_fp16)
651
+
652
+ codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
653
+ pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
654
+ return pixels
655
+
656
+ def training_step(self, batch, batch_idx):
657
+ images, texts = batch
658
+ logits_img, logits_txt, codes = self(images, texts)
659
+
660
+ loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
661
+ loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
662
+ self.log("train/loss_img", loss_img, on_step=True, on_epoch=True, prog_bar=False, logger=True)
663
+ self.log("train/loss_txt", loss_txt, on_step=True, on_epoch=True, prog_bar=False, logger=True)
664
+ return loss_img + loss_txt
665
+
666
+ def validation_step(self, batch, batch_idx):
667
+ images, texts = batch
668
+ logits_img, logits_txt, codes = self(images, texts)
669
+ # print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
670
+
671
+ loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
672
+ loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
673
+ self.log("val/loss_img", loss_img, on_step=False, on_epoch=True, prog_bar=False, logger=True)
674
+ self.log("val/loss_txt", loss_txt, on_step=False, on_epoch=True, prog_bar=False, logger=True)
675
+ return loss_img + loss_txt
676
+
677
+ @torch.no_grad()
678
+ def predict_step(self, batch, batch_idx, return_images=False):
679
+ orig_images, texts = batch
680
+
681
+ # extra for checks
682
+ logits_img, logits_txt, codes = self(orig_images, texts)
683
+ pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
684
+ bs = orig_images.shape[0]
685
+ pred = pred.view(bs, 16, 16) # [B, 16, 16]
686
+ pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy() # [B, 256, 256]
687
+ pixels = np.transpose(pixels, (0, 2, 3, 1))
688
+
689
+
690
+ # print(texts.shape, orig_images.shape)
691
+ # concatenate the list of prompts (split by n_head) for better downstream processing
692
+ past_key_values_prompt = self.get_prompt(bsz=5, eval=True)
693
+ # print(past_key_values_prompt[0].shape, past_key_values_prompt[1].shape, len(past_key_values_prompt))
694
+ past_key_values_prompt = torch.cat([x.unsqueeze(0) for x in past_key_values_prompt], dim=0)
695
+ n_layers, temp, bs, n_heads, seq_len, n_dim = past_key_values_prompt.shape
696
+ past_key_values_prompt = past_key_values_prompt.view(n_layers, temp, bs*n_heads, seq_len, n_dim)
697
+ # print(past_key_values_prompt.shape)
698
+ images = []
699
+ for i, t in enumerate(texts):
700
+ pixels = self.sampling(t, past_key_values_prompt, top_k=16, num_candidates=5, labels=codes[i]).cpu().numpy()
701
+ pixels = np.transpose(pixels, (0, 2, 3, 1))
702
+ images.append(pixels)
703
+ # images.extend([p for p in pixels])
704
+ # print([i.shape for i in images])
705
+
706
+
707
+ if return_images:
708
+ return images
709
+ else:
710
+ save_image(orig_images, pixels, './out/images/pororo_prefix', batch_idx+10)
711
+ save_image(orig_images, images, './out/images/pororo_prefix', batch_idx)
712
+
713
+
714
+ class ConditionalDalle(Dalle):
715
+ """Classification Head for transformer encoders"""
716
+ def __init__(self, config):
717
+ super().__init__(config)
718
+ print('Initializing the Conditional Dalle model')
719
+
720
+ self.config = config
721
+
722
+ print('Setting up Cross-attention Layers')
723
+ self.init_cross_attention(list(range(2,42,3)), config.stage2.hparams)
724
+
725
+ ###### NUM PARAMS #########
726
+ total_param = 0
727
+ for name, param in self.named_parameters():
728
+ # print(param.shape)
729
+ total_param += param.numel()
730
+ print('Total parameters is {}'.format(total_param))
731
+
732
+ @classmethod
733
+ def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
734
+
735
+ # if not args.model_name_or_path:
736
+ # args.model_name_or_path = args.prefix_model_name_or_path
737
+
738
+ path = args.model_name_or_path
739
+ config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
740
+ if args.do_train:
741
+ config_base = get_base_config('finetuning')
742
+ config_update = OmegaConf.merge(config_base, config_new)
743
+ for key, val in vars(args).items():
744
+ if key in config_update.optimizer.keys():
745
+ OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
746
+ if key in config_update.experiment.keys():
747
+ OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
748
+ else:
749
+ config_base = get_base_config('default')
750
+ config_update = OmegaConf.merge(config_base, config_new)
751
+
752
+ model = cls(config_update)
753
+ model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
754
+ context_length=model.config_dataset.context_length,
755
+ lowercase=True,
756
+ dropout=None)
757
+ print(model.cross_attention_idxs)
758
+ # print(next(model.cross_attention_layers[0].parameters()).is_cuda)
759
+
760
+ if args.dalle_path:
761
+ print("Loading model from pretrained checkpoint %s" % args.dalle_path)
762
+ # model.from_ckpt(args.model_name_or_path)
763
+ model.load_state_dict(torch.load(args.dalle_path)['model_state_dict'])
764
+ else:
765
+ print("Loading models from checkpoint %s" % path)
766
+ model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
767
+ model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
768
+
769
+ return model, config_update
770
+
771
+
772
+ def init_cross_attention(self, cross_attention_layers, hparams):
773
+ self.cross_attention_idxs = cross_attention_layers
774
+ self.cross_attention_layers = [CrossAttentionLayer(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
775
+ embed_dim=hparams.embed_dim,
776
+ n_heads=hparams.n_heads,
777
+ attn_bias=hparams.attn_bias,
778
+ resid_pdrop=hparams.resid_pdrop,
779
+ attn_pdrop=hparams.attn_pdrop) for i in cross_attention_layers]
780
+
781
+
782
+ def forward(self,
783
+ images: torch.FloatTensor,
784
+ src_images: Optional[torch.FloatTensor],
785
+ texts: Optional[torch.LongTensor],
786
+ **kwargs,
787
+ ):
788
+
789
+ #{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
790
+
791
+ # print(images.shape, src_images.shape, texts.shape)
792
+ with torch.no_grad():
793
+ with autocast(enabled=False):
794
+ codes = self.stage1.get_codes(images).detach()
795
+ src_codes = self.stage1.get_codes(src_images).detach()
796
+
797
+ pos_enc_tokens = get_positional_encoding(texts, mode='1d')
798
+ codes = codes.clone().detach()
799
+ pos_enc_code = get_positional_encoding(codes, mode='1d')
800
+ src_codes = src_codes.clone().detach()
801
+ src_pos_enc_code = get_positional_encoding(src_codes, mode='1d')
802
+ # codes = codes.unsqueeze(-1)
803
+ # pos_enc_code = pos_enc_code.unsqueeze(-1)
804
+ # print(images.shape, codes.shape, texts.shape)
805
+ logits_img, logits_txt = self.stage2.forward_with_context(codes, texts,
806
+ pos_enc_code, pos_enc_tokens, src_codes, src_pos_enc_code,
807
+ self.cross_attention_idxs, self.cross_attention_layers)
808
+ # print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
809
+ return logits_img, logits_txt, codes
810
+
811
+ @torch.no_grad()
812
+ def sampling(self,
813
+ prompt: torch.LongTensor,
814
+ source: torch.FloatTensor,
815
+ top_k: int = 256,
816
+ top_p: Optional[float] = None,
817
+ softmax_temperature: float = 1.0,
818
+ num_candidates: int = 96,
819
+ device: str = 'cuda:0',
820
+ use_fp16: bool = True) -> torch.FloatTensor:
821
+ self.stage1.eval()
822
+ self.stage2.eval()
823
+
824
+ if type(prompt) == str:
825
+ tokens = self.tokenizer.encode(prompt)
826
+ tokens = torch.LongTensor(tokens.ids)
827
+ else:
828
+ tokens = prompt
829
+
830
+ tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
831
+
832
+ # Check if the encoding works as intended
833
+ # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
834
+
835
+ tokens = tokens.to(device)
836
+ source = source.to(device)
837
+
838
+ with autocast(enabled=False):
839
+ src_codes = self.stage1.get_codes(source).detach()
840
+ src_codes = torch.repeat_interleave(src_codes, num_candidates, dim=0)
841
+
842
+ codes = sampling_conditional(self.stage2,
843
+ self.cross_attention_idxs,
844
+ self.cross_attention_layers,
845
+ tokens,
846
+ src_codes,
847
+ top_k=top_k,
848
+ top_p=top_p,
849
+ softmax_temperature=softmax_temperature,
850
+ use_fp16=use_fp16)
851
+ codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
852
+ pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
853
+ return pixels
854
+
855
+ def training_step(self, batch, batch_idx):
856
+ images, texts = batch
857
+ logits_img, logits_txt, codes = self(images, texts)
858
+
859
+ loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
860
+ loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
861
+ self.log("train/loss_img", loss_img, on_step=True, on_epoch=True, prog_bar=False, logger=True)
862
+ self.log("train/loss_txt", loss_txt, on_step=True, on_epoch=True, prog_bar=False, logger=True)
863
+ return loss_img + loss_txt
864
+
865
+ def validation_step(self, batch, batch_idx):
866
+ images, texts = batch
867
+ logits_img, logits_txt, codes = self(images, texts)
868
+ # print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
869
+
870
+ loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
871
+ loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
872
+ self.log("val/loss_img", loss_img, on_step=False, on_epoch=True, prog_bar=False, logger=True)
873
+ self.log("val/loss_txt", loss_txt, on_step=False, on_epoch=True, prog_bar=False, logger=True)
874
+ return loss_img + loss_txt
875
+
876
+ @torch.no_grad()
877
+ def predict_step(self, batch, batch_idx):
878
+ orig_images, texts = batch
879
+ # concatenate the list of prompts (split by n_head) for better downstream processing
880
+ past_key_values_prompt = self.get_prompt(bsz=5)
881
+ past_key_values_prompt = torch.cat([x.unsqueeze(0) for x in past_key_values_prompt], dim=0)
882
+ images = []
883
+ for t in texts:
884
+ pixels = self.sampling(t, past_key_values_prompt, top_k=64, num_candidates=5).cpu().numpy()
885
+ pixels = np.transpose(pixels, (0, 2, 3, 1))
886
+ images.append(pixels)
887
+ # images.extend([p for p in pixels])
888
+ # print([i.shape for i in images])
889
+
890
+ save_image(orig_images, images, './out/images/', batch_idx)
891
+
892
+
893
+ class PromptConditionalDalle(Dalle):
894
+ """Classification Head for transformer encoders"""
895
+ def __init__(self, config):
896
+ super().__init__(config)
897
+ print('Initializing the Conditional Dalle model')
898
+
899
+ self.config = config
900
+
901
+ print('Setting up Cross-attention Layers')
902
+ self.init_cross_attention(list(range(2,42,3)), config.stage2.hparams)
903
+
904
+ self.n_embd = config.stage2.hparams.embed_dim
905
+ self.preseqlen = config.story.preseqlen
906
+ self.prefix_dropout = config.story.prefix_dropout
907
+
908
+ # DIFFERENT PARAMETRIZATION:
909
+
910
+ print('[Full prompt-tuning Setting :) ]')
911
+ self.input_tokens = torch.arange(self.preseqlen).long()
912
+ self.wte = nn.Embedding(self.preseqlen, self.n_embd)
913
+ self.control_trans = nn.Sequential(
914
+ nn.Linear(self.n_embd, self.n_embd),
915
+ nn.Tanh(),
916
+ nn.Linear(self.n_embd, self.n_embd))
917
+ self.get_prompt = self.get_prompt_p5
918
+ self.dropout = nn.Dropout(self.prefix_dropout)
919
+
920
+ ###### NUM PARAMS #########
921
+ total_param = 0
922
+ for name, param in self.named_parameters():
923
+ # print(param.shape)
924
+ total_param += param.numel()
925
+ print('Total parameters is {}'.format(total_param))
926
+
927
+ @classmethod
928
+ def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
929
+
930
+ # if not args.model_name_or_path:
931
+ # args.model_name_or_path = args.prefix_model_name_or_path
932
+
933
+ path = args.prefix_model_name_or_path
934
+ path = _MODELS[path] if path in _MODELS else path
935
+ path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))
936
+
937
+ config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
938
+ if args.do_train:
939
+ config_base = get_base_config('story')
940
+ config_update = OmegaConf.merge(config_base, config_new)
941
+ for key, val in vars(args).items():
942
+ if key in config_update.story.keys():
943
+ OmegaConf.update(config_update, "story.%s" % key, val, merge=False)
944
+ if key in config_update.optimizer.keys():
945
+ OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
946
+ if key in config_update.experiment.keys():
947
+ OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
948
+ else:
949
+ config_base = get_base_config('default')
950
+ config_update = OmegaConf.merge(config_base, config_new)
951
+
952
+ model = cls(config_update)
953
+ model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
954
+ context_length=model.config_dataset.context_length,
955
+ lowercase=True,
956
+ dropout=None)
957
+ print(model.cross_attention_idxs)
958
+ # print(next(model.cross_attention_layers[0].parameters()).is_cuda)
959
+
960
+ if args.model_name_or_path:
961
+ print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
962
+ # model.from_ckpt(args.model_name_or_path)
963
+ try:
964
+ model.load_state_dict(torch.load(args.model_name_or_path)['state_dict'])
965
+ except KeyError:
966
+ model.load_state_dict(torch.load(args.model_name_or_path)['model_state_dict'])
967
+
968
+ else:
969
+ print("Loading models from checkpoint %s" % path)
970
+ model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
971
+ model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
972
+
973
+ return model, config_update
974
+
975
+
976
+ def init_cross_attention(self, cross_attention_layers, hparams):
977
+ self.cross_attention_idxs = cross_attention_layers
978
+ self.cross_attention_layers = [CrossAttentionLayer(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
979
+ embed_dim=hparams.embed_dim,
980
+ n_heads=hparams.n_heads,
981
+ attn_bias=hparams.attn_bias,
982
+ resid_pdrop=hparams.resid_pdrop,
983
+ attn_pdrop=hparams.attn_pdrop) for i in cross_attention_layers]
984
+
985
+ def get_prompt_p5(self, bsz=None, eval=False):
986
+ input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
987
+ temp_control = self.wte(input_tokens)
988
+ past_key_values = self.control_trans(temp_control) #bsz, seqlen, layer*emb
989
+ if not eval:
990
+ past_key_values = self.dropout(past_key_values)
991
+ return past_key_values
992
+
993
+ def forward(self,
994
+ images: torch.FloatTensor,
995
+ src_images: Optional[torch.FloatTensor],
996
+ texts: Optional[torch.LongTensor],
997
+ **kwargs,
998
+ ):
999
+
1000
+ #{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
1001
+
1002
+ # print(images.shape, src_images.shape, texts.shape)
1003
+ with torch.no_grad():
1004
+ with autocast(enabled=False):
1005
+ codes = self.stage1.get_codes(images).detach()
1006
+ src_codes = self.stage1.get_codes(src_images).detach()
1007
+
1008
+ B, C, H, W = images.shape
1009
+ prompt = self.get_prompt(bsz=B)
1010
+ pos_enc_prompt = get_positional_encoding(self.input_tokens.unsqueeze(0).expand(B, -1).to(self.device), mode='1d')
1011
+
1012
+ pos_enc_tokens = get_positional_encoding(texts, mode='1d')
1013
+ codes = codes.clone().detach()
1014
+ pos_enc_code = get_positional_encoding(codes, mode='1d')
1015
+ src_codes = src_codes.clone().detach()
1016
+ src_pos_enc_code = get_positional_encoding(src_codes, mode='1d')
1017
+ # codes = codes.unsqueeze(-1)
1018
+ # pos_enc_code = pos_enc_code.unsqueeze(-1)
1019
+ # print(images.shape, codes.shape, texts.shape)
1020
+ logits_img, logits_txt = self.stage2.forward_with_context(codes, texts,
1021
+ pos_enc_code, pos_enc_tokens, src_codes, src_pos_enc_code,
1022
+ self.cross_attention_idxs, self.cross_attention_layers,
1023
+ prompt=prompt, pos_prompt=pos_enc_prompt)
1024
+ # print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
1025
+ return logits_img, logits_txt, codes
1026
+
1027
+ @torch.no_grad()
1028
+ def sampling(self,
1029
+ tokens: torch.LongTensor,
1030
+ prompt: torch.LongTensor,
1031
+ source: torch.FloatTensor,
1032
+ top_k: int = 256,
1033
+ top_p: Optional[float] = None,
1034
+ softmax_temperature: float = 1.0,
1035
+ num_candidates: int = 96,
1036
+ device: str = 'cuda:0',
1037
+ use_fp16: bool = True,
1038
+ labels=None) -> torch.FloatTensor:
1039
+
1040
+ self.stage1.eval()
1041
+ self.stage2.eval()
1042
+
1043
+ if type(tokens) == str:
1044
+ tokens = self.tokenizer.encode(prompt)
1045
+ tokens = torch.LongTensor(tokens.ids)
1046
+ else:
1047
+ pass
1048
+
1049
+ tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
1050
+
1051
+ # Check if the encoding works as intended
1052
+ # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
1053
+
1054
+ tokens = tokens.to(device)
1055
+ source = source.to(device)
1056
+
1057
+ pos_enc_prompt = get_positional_encoding(self.input_tokens.unsqueeze(0).expand(num_candidates, -1).to(self.device), mode='1d')
1058
+
1059
+ with autocast(enabled=False):
1060
+ src_codes = self.stage1.get_codes(source).detach()
1061
+ src_codes = torch.repeat_interleave(src_codes, num_candidates, dim=0)
1062
+
1063
+ codes = sampling_conditional(self.stage2,
1064
+ self.cross_attention_idxs,
1065
+ self.cross_attention_layers,
1066
+ tokens,
1067
+ src_codes,
1068
+ top_k=top_k,
1069
+ top_p=top_p,
1070
+ softmax_temperature=softmax_temperature,
1071
+ use_fp16=use_fp16,
1072
+ prompt=prompt,
1073
+ pos_prompt=pos_enc_prompt)
1074
+
1075
+ codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
1076
+ pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
1077
+ return pixels
1078
+
1079
+
1080
+ @torch.no_grad()
1081
+ def predict_step(self, batch, batch_idx, return_images=False):
1082
+ orig_images, texts = batch
1083
+ # concatenate the list of prompts (split by n_head) for better downstream processing
1084
+
1085
+ # extra for checks
1086
+ logits_img, logits_txt, codes = self(orig_images, texts)
1087
+ pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
1088
+ bs = orig_images.shape[0]
1089
+ pred = pred.view(bs, 16, 16) # [B, 16, 16]
1090
+ pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy() # [B, 256, 256]
1091
+ pixels = np.transpose(pixels, (0, 2, 3, 1))
1092
+
1093
+ prompt = self.get_prompt(bsz=5, eval=True)
1094
+
1095
+ images = []
1096
+ for t in texts:
1097
+ pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
1098
+ pixels = np.transpose(pixels, (0, 2, 3, 1))
1099
+ images.append(pixels)
1100
+ # images.extend([p for p in pixels])
1101
+ # print([i.shape for i in images])
1102
+
1103
+ if return_images:
1104
+ return images
1105
+ else:
1106
+ save_image(orig_images, pixels, './out/images/pororo_story', batch_idx+10)
1107
+ save_image(orig_images, images, './out/images/pororo_story', batch_idx)
1108
+
1109
+
1110
+ class StoryDalle(Dalle):
1111
+ """Base model with story block"""
1112
+ def __init__(self, config):
1113
+ super().__init__(config)
1114
+ print('Initializing the Conditional Dalle model')
1115
+
1116
+ self.config = config
1117
+
1118
+ self.story_linear = nn.Linear(config.story.sent_embed, config.stage2.hparams.embed_dim)
1119
+ self.story_block = Block(ctx_len=config.story.story_len,
1120
+ embed_dim=config.stage2.hparams.embed_dim,
1121
+ n_heads=config.stage2.hparams.n_heads,
1122
+ mlp_bias=config.stage2.hparams.mlp_bias,
1123
+ attn_bias=config.stage2.hparams.attn_bias,
1124
+ resid_pdrop=config.stage2.hparams.resid_pdrop,
1125
+ attn_pdrop=config.stage2.hparams.attn_pdrop,
1126
+ gelu_use_approx=config.stage2.hparams.gelu_use_approx)
1127
+
1128
+ if self.config.story.prompt:
1129
+ self.n_embd = config.stage2.hparams.embed_dim
1130
+ self.preseqlen = config.story.preseqlen
1131
+ self.prefix_dropout = config.story.prefix_dropout
1132
+
1133
+ # DIFFERENT PARAMETRIZATION:
1134
+
1135
+ print('[Full prompt-tuning Setting :) ]')
1136
+ self.input_tokens = torch.arange(self.preseqlen).long()
1137
+ self.wte = nn.Embedding(self.preseqlen, self.n_embd)
1138
+ self.control_trans = nn.Sequential(
1139
+ nn.Linear(self.n_embd, self.n_embd),
1140
+ nn.Tanh(),
1141
+ nn.Linear(self.n_embd, self.n_embd))
1142
+ self.get_prompt = self.get_prompt_p5
1143
+ self.dropout = nn.Dropout(self.prefix_dropout)
1144
+
1145
+ if self.config.story.condition:
1146
+ print('Setting up Cross-attention Layers')
1147
+ self.init_cross_attention(list(range(2,42,3)), config.stage2.hparams)
1148
+
1149
+ ###### NUM PARAMS #########
1150
+ total_param = 0
1151
+ for name, param in self.named_parameters():
1152
+ # print(param.shape)
1153
+ total_param += param.numel()
1154
+ print('Total parameters is {}'.format(total_param))
1155
+
1156
+ @classmethod
1157
+ def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
1158
+
1159
+ # if not args.model_name_or_path:
1160
+ # args.model_name_or_path = args.prefix_model_name_or_path
1161
+
1162
+ path = args.prefix_model_name_or_path
1163
+ path = _MODELS[path] if path in _MODELS else path
1164
+ path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))
1165
+
1166
+ config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
1167
+ # if args.do_train:
1168
+ config_base = get_base_config('story')
1169
+ config_update = OmegaConf.merge(config_base, config_new)
1170
+ for key, val in vars(args).items():
1171
+ if key in config_update.story.keys():
1172
+ OmegaConf.update(config_update, "story.%s" % key, val, merge=False)
1173
+ if key in config_update.optimizer.keys():
1174
+ OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
1175
+ if key in config_update.experiment.keys():
1176
+ OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
1177
+ # else:
1178
+ # config_base = get_base_config('story')
1179
+ # config_update = OmegaConf.merge(config_base, config_new)
1180
+ # print(next(model.cross_attention_layers[0].parameters()).is_cuda)
1181
+
1182
+ if args.model_name_or_path:
1183
+ if 'pororo' in args.model_name_or_path:
1184
+ config_update.stage2.vocab_size_txt = config_update.stage2.vocab_size_txt + 9
1185
+ elif 'flintstones' in args.model_name_or_path:
1186
+ config_update.stage2.vocab_size_txt = config_update.stage2.vocab_size_txt + 7
1187
+ model = cls(config_update)
1188
+ model_dir = os.path.dirname(args.model_name_or_path)
1189
+ print(model_dir)
1190
+ model.tokenizer = build_tokenizer(model_dir,
1191
+ context_length=model.config_dataset.context_length,
1192
+ lowercase=True,
1193
+ dropout=None)
1194
+ print("Loaded tokenizer from finetuned checkpoint")
1195
+ print(model.cross_attention_idxs)
1196
+ print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
1197
+ # model.from_ckpt(args.model_name_or_path)
1198
+ try:
1199
+ model.load_state_dict(torch.load(args.model_name_or_path)['state_dict'])
1200
+ except KeyError:
1201
+ model.load_state_dict(torch.load(args.model_name_or_path)['model_state_dict'])
1202
+ else:
1203
+ model = cls(config_update)
1204
+ print(model.cross_attention_idxs)
1205
+ print("Loading models from checkpoint %s" % path)
1206
+ model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
1207
+ model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
1208
+
1209
+ model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
1210
+ context_length=model.config_dataset.context_length,
1211
+ lowercase=True,
1212
+ dropout=None)
1213
+
1214
+
1215
+ return model, config_update
1216
+
1217
+
1218
+ def init_cross_attention(self, cross_attention_layers, hparams):
1219
+ self.cross_attention_idxs = cross_attention_layers
1220
+ self.cross_attention_layers = [CrossAttentionLayer(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
1221
+ embed_dim=hparams.embed_dim,
1222
+ n_heads=hparams.n_heads,
1223
+ attn_bias=hparams.attn_bias,
1224
+ resid_pdrop=hparams.resid_pdrop,
1225
+ attn_pdrop=hparams.attn_pdrop) for i in cross_attention_layers]
1226
+
1227
+ def get_prompt_p5(self, bsz=None, eval=False):
1228
+ input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
1229
+ temp_control = self.wte(input_tokens)
1230
+ past_key_values = self.control_trans(temp_control) #bsz, seqlen, layer*emb
1231
+ if not eval:
1232
+ past_key_values = self.dropout(past_key_values)
1233
+ return past_key_values
1234
+
1235
+ def forward(self,
1236
+ images: torch.FloatTensor,
1237
+ src_images: Optional[torch.FloatTensor],
1238
+ texts: Optional[torch.LongTensor],
1239
+ sent_embeds: Optional[torch.FloatTensor],
1240
+ **kwargs,
1241
+ ):
1242
+
1243
+ # print(images.shape, src_images.shape, texts.shape, sent_embeds.shape)
1244
+
1245
+ B, L, C, H, W = images.shape
1246
+ images = images.view(B*L, C, H, W)
1247
+ src_images = src_images.unsqueeze(1).expand(-1, L, -1, -1, -1).reshape(B*L, C, H, W)
1248
+ sent_embeds = self.story_block(self.story_linear(sent_embeds)).view(B * L, -1).unsqueeze(1)
1249
+ texts = texts.view(B * L, -1)
1250
+
1251
+ #{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
1252
+
1253
+ with torch.no_grad():
1254
+ with autocast(enabled=False):
1255
+ codes = self.stage1.get_codes(images).detach()
1256
+ src_codes = self.stage1.get_codes(src_images).detach()
1257
+
1258
+ B, C, H, W = images.shape
1259
+
1260
+ if self.config.story.prompt:
1261
+ prompt = self.get_prompt(bsz=B)
1262
+ prompt = torch.cat([prompt, sent_embeds], dim=1)
1263
+ else:
1264
+ prompt = sent_embeds
1265
+
1266
+ # dim = 0 for full-model finetuning??
1267
+ pos_enc_prompt = get_positional_encoding(torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B, -1).to(self.device),
1268
+ mode='1d')
1269
+
1270
+ pos_enc_tokens = get_positional_encoding(texts, mode='1d')
1271
+ codes = codes.clone().detach()
1272
+ pos_enc_code = get_positional_encoding(codes, mode='1d')
1273
+ src_codes = src_codes.clone().detach()
1274
+ src_pos_enc_code = get_positional_encoding(src_codes, mode='1d')
1275
+ # codes = codes.unsqueeze(-1)
1276
+ # pos_enc_code = pos_enc_code.unsqueeze(-1)
1277
+ # print(images.shape, codes.shape, texts.shape)
1278
+ if self.config.story.condition:
1279
+ logits_img, logits_txt = self.stage2.forward_with_context(codes, texts,
1280
+ pos_enc_code, pos_enc_tokens, src_codes, src_pos_enc_code,
1281
+ self.cross_attention_idxs, self.cross_attention_layers,
1282
+ prompt=prompt, pos_prompt=pos_enc_prompt)
1283
+ else:
1284
+ logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens, prompt=prompt,
1285
+ pos_prompt=pos_enc_prompt)
1286
+
1287
+ # print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
1288
+ return logits_img, logits_txt, codes
1289
+
1290
+ @torch.no_grad()
1291
+ def sampling(self,
1292
+ tokens: torch.LongTensor,
1293
+ source: torch.FloatTensor,
1294
+ sent_embeds: torch.FloatTensor,
1295
+ top_k: int = 256,
1296
+ top_p: Optional[float] = None,
1297
+ softmax_temperature: float = 1.0,
1298
+ num_candidates: int = 96,
1299
+ device: str = 'cuda:0',
1300
+ use_fp16: bool = True,
1301
+ labels=None,
1302
+ prompt = None) -> torch.FloatTensor:
1303
+
1304
+ self.stage1.eval()
1305
+ self.stage2.eval()
1306
+
1307
+ if type(tokens) == str:
1308
+ tokens = self.tokenizer.encode(tokens)
1309
+ tokens = torch.LongTensor(tokens.ids)
1310
+
1311
+ # tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
1312
+
1313
+ # Check if the encoding works as intended
1314
+ # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
1315
+
1316
+ tokens = tokens.to(device)
1317
+ source = source.to(device)
1318
+
1319
+ # print(tokens.shape, sent_embeds.shape, prompt.shape)
1320
+ B, L, _ = sent_embeds.shape
1321
+ sent_embeds = self.story_block(self.story_linear(sent_embeds)).view(B * L, -1).unsqueeze(1)
1322
+ if prompt is not None:
1323
+ prompt = torch.cat([prompt, sent_embeds], dim=1)
1324
+ else:
1325
+ prompt = sent_embeds
1326
+ pos_enc_prompt = get_positional_encoding(torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B*L, -1).to(self.device), mode='1d')
1327
+
1328
+ with autocast(enabled=False):
1329
+ src_codes = self.stage1.get_codes(source).detach()
1330
+ src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len, dim=0)
1331
+ print(tokens.shape, src_codes.shape, prompt.shape)
1332
+ if self.config.story.condition:
1333
+ codes = sampling_conditional(self.stage2,
1334
+ self.cross_attention_idxs,
1335
+ self.cross_attention_layers,
1336
+ tokens,
1337
+ src_codes,
1338
+ top_k=top_k,
1339
+ top_p=top_p,
1340
+ softmax_temperature=softmax_temperature,
1341
+ use_fp16=use_fp16,
1342
+ prompt=prompt,
1343
+ pos_prompt=pos_enc_prompt)
1344
+ else:
1345
+ codes = sampling(self.stage2,
1346
+ tokens,
1347
+ top_k=top_k,
1348
+ top_p=top_p,
1349
+ softmax_temperature=softmax_temperature,
1350
+ use_fp16=use_fp16,
1351
+ prompt=prompt,
1352
+ pos_prompt=pos_enc_prompt)
1353
+
1354
+ codes = codes.view(self.config.story.story_len, 16, 16) # [B, 16, 16]
1355
+ pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
1356
+ return pixels
1357
+
1358
+ @torch.no_grad()
1359
+ def sampling_batch(self,
1360
+ tokens: torch.LongTensor,
1361
+ source: torch.FloatTensor,
1362
+ sent_embeds: torch.FloatTensor,
1363
+ top_k: int = 256,
1364
+ top_p: Optional[float] = None,
1365
+ softmax_temperature: float = 1.0,
1366
+ num_candidates: int = 96,
1367
+ device: str = 'cuda:0',
1368
+ use_fp16: bool = True,
1369
+ labels=None,
1370
+ prompt=None, n_candidates=1) -> torch.FloatTensor:
1371
+
1372
+ self.stage1.eval()
1373
+ self.stage2.eval()
1374
+
1375
+ if type(tokens) == str:
1376
+ tokens = self.tokenizer.encode(tokens)
1377
+ tokens = torch.LongTensor(tokens.ids)
1378
+
1379
+ # tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
1380
+
1381
+ # Check if the encoding works as intended
1382
+ # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
1383
+
1384
+ tokens = tokens.to(device)
1385
+ source = source.to(device)
1386
+
1387
+ # print(tokens.shape, sent_embeds.shape, prompt.shape)
1388
+ B, L, _ = sent_embeds.shape
1389
+ sent_embeds = self.story_block(self.story_linear(sent_embeds)).view(B * L, -1).unsqueeze(1)
1390
+ if prompt is not None:
1391
+ prompt = torch.cat([prompt, sent_embeds], dim=1)
1392
+ else:
1393
+ prompt = sent_embeds
1394
+ pos_enc_prompt = get_positional_encoding(
1395
+ torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B * L, -1).to(self.device), mode='1d')
1396
+
1397
+ with autocast(enabled=False):
1398
+ src_codes = self.stage1.get_codes(source).detach()
1399
+
1400
+ # repeat inputs to adjust to n_candidates and story length
1401
+ src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
1402
+ prompt = prompt.repeat(n_candidates, 1, 1)
1403
+ pos_enc_prompt = pos_enc_prompt.repeat(n_candidates, 1)
1404
+ tokens = tokens.repeat(n_candidates, 1)
1405
+ print(tokens.shape, src_codes.shape, prompt.shape, pos_enc_prompt.shape)
1406
+ if self.config.story.condition:
1407
+ codes = sampling_conditional(self.stage2,
1408
+ self.cross_attention_idxs,
1409
+ self.cross_attention_layers,
1410
+ tokens,
1411
+ src_codes,
1412
+ top_k=top_k,
1413
+ top_p=top_p,
1414
+ softmax_temperature=softmax_temperature,
1415
+ use_fp16=use_fp16,
1416
+ prompt=prompt,
1417
+ pos_prompt=pos_enc_prompt)
1418
+ else:
1419
+ codes = sampling(self.stage2,
1420
+ tokens,
1421
+ top_k=top_k,
1422
+ top_p=top_p,
1423
+ softmax_temperature=softmax_temperature,
1424
+ use_fp16=use_fp16,
1425
+ prompt=prompt,
1426
+ pos_prompt=pos_enc_prompt)
1427
+
1428
+ codes = codes.view(self.config.story.story_len * n_candidates, 16, 16) # [B, 16, 16]
1429
+ print(codes.shape)
1430
+ pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 3, 256, 256]
1431
+ print(pixels.shape)
1432
+ return pixels.view(n_candidates, self.config.story.story_len, pixels.shape[-3], pixels.shape[-2], pixels.shape[-1])
1433
+
1434
+
1435
+ @torch.no_grad()
1436
+ def predict_step(self, batch, batch_idx, return_images=False):
1437
+ orig_images, texts = batch
1438
+ # concatenate the list of prompts (split by n_head) for better downstream processing
1439
+
1440
+ # extra for checks
1441
+ logits_img, logits_txt, codes = self(orig_images, texts)
1442
+ pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
1443
+ bs = orig_images.shape[0]
1444
+ pred = pred.view(bs, 16, 16) # [B, 16, 16]
1445
+ pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy() # [B, 256, 256]
1446
+ pixels = np.transpose(pixels, (0, 2, 3, 1))
1447
+
1448
+ prompt = self.get_prompt(bsz=5, eval=True)
1449
+
1450
+ images = []
1451
+ for t in texts:
1452
+ pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
1453
+ pixels = np.transpose(pixels, (0, 2, 3, 1))
1454
+ images.append(pixels)
1455
+ # images.extend([p for p in pixels])
1456
+ # print([i.shape for i in images])
1457
+
1458
+ if return_images:
1459
+ return images
1460
+ else:
1461
+ save_image(orig_images, pixels, './out/images/pororo_story', batch_idx+10)
1462
+ save_image(orig_images, images, './out/images/pororo_story', batch_idx)
dalle/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (34.9 kB). View file
 
dalle/models/__pycache__/prefix_tuning_model.cpython-38.pyc ADDED
Binary file (5.05 kB). View file
 
dalle/models/__pycache__/tokenizer.cpython-38.pyc ADDED
Binary file (974 Bytes). View file
 
dalle/models/stage1/__pycache__/layers.cpython-38.pyc ADDED
Binary file (7.85 kB). View file
 
dalle/models/stage1/__pycache__/vqgan.cpython-38.pyc ADDED
Binary file (4.04 kB). View file
 
dalle/models/stage1/layers.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Modified from VQGAN (https://github.com/CompVis/taming-transformers)
3
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4
+ # ------------------------------------------------------------------------------------
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Tuple, Optional
9
+
10
+
11
+ def nonlinearity(x):
12
+ # swish
13
+ return x*torch.sigmoid(x)
14
+
15
+
16
+ def Normalize(in_channels):
17
+ return torch.nn.GroupNorm(num_groups=32,
18
+ num_channels=in_channels,
19
+ eps=1e-6,
20
+ affine=True)
21
+
22
+
23
+ class Upsample(nn.Module):
24
+ def __init__(self, in_channels, with_conv):
25
+ super().__init__()
26
+ self.with_conv = with_conv
27
+ if self.with_conv:
28
+ self.conv = torch.nn.Conv2d(in_channels,
29
+ in_channels,
30
+ kernel_size=3,
31
+ stride=1,
32
+ padding=1)
33
+
34
+ def forward(self, x):
35
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
36
+ if self.with_conv:
37
+ x = self.conv(x)
38
+ return x
39
+
40
+
41
+ class Downsample(nn.Module):
42
+ def __init__(self, in_channels, with_conv):
43
+ super().__init__()
44
+ self.with_conv = with_conv
45
+ if self.with_conv:
46
+ # no asymmetric padding in torch conv, must do it ourselves
47
+ self.conv = torch.nn.Conv2d(in_channels,
48
+ in_channels,
49
+ kernel_size=3,
50
+ stride=2,
51
+ padding=0)
52
+
53
+ def forward(self, x):
54
+ if self.with_conv:
55
+ pad = (0, 1, 0, 1)
56
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
57
+ x = self.conv(x)
58
+ else:
59
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
60
+ return x
61
+
62
+
63
+ class ResnetBlock(nn.Module):
64
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
65
+ dropout, temb_channels=512):
66
+ assert temb_channels == 0
67
+ super().__init__()
68
+ self.in_channels = in_channels
69
+ out_channels = in_channels if out_channels is None else out_channels
70
+ self.out_channels = out_channels
71
+ self.use_conv_shortcut = conv_shortcut
72
+
73
+ self.norm1 = Normalize(in_channels)
74
+ self.conv1 = torch.nn.Conv2d(in_channels,
75
+ out_channels,
76
+ kernel_size=3,
77
+ stride=1,
78
+ padding=1)
79
+ self.norm2 = Normalize(out_channels)
80
+ self.dropout = torch.nn.Dropout(dropout)
81
+ self.conv2 = torch.nn.Conv2d(out_channels,
82
+ out_channels,
83
+ kernel_size=3,
84
+ stride=1,
85
+ padding=1)
86
+ if self.in_channels != self.out_channels:
87
+ if self.use_conv_shortcut:
88
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
89
+ out_channels,
90
+ kernel_size=3,
91
+ stride=1,
92
+ padding=1)
93
+ else:
94
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
95
+ out_channels,
96
+ kernel_size=1,
97
+ stride=1,
98
+ padding=0)
99
+
100
+ def forward(self, x, temb=None):
101
+ assert temb is None
102
+
103
+ h = x
104
+ h = self.norm1(h)
105
+ h = nonlinearity(h)
106
+ h = self.conv1(h)
107
+
108
+ h = self.norm2(h)
109
+ h = nonlinearity(h)
110
+ h = self.dropout(h)
111
+ h = self.conv2(h)
112
+
113
+ if self.in_channels != self.out_channels:
114
+ if self.use_conv_shortcut:
115
+ x = self.conv_shortcut(x)
116
+ else:
117
+ x = self.nin_shortcut(x)
118
+ return x+h
119
+
120
+
121
+ class AttnBlock(nn.Module):
122
+ def __init__(self, in_channels):
123
+ super().__init__()
124
+ self.in_channels = in_channels
125
+
126
+ self.norm = Normalize(in_channels)
127
+ self.q = torch.nn.Conv2d(in_channels,
128
+ in_channels,
129
+ kernel_size=1,
130
+ stride=1,
131
+ padding=0)
132
+ self.k = torch.nn.Conv2d(in_channels,
133
+ in_channels,
134
+ kernel_size=1,
135
+ stride=1,
136
+ padding=0)
137
+ self.v = torch.nn.Conv2d(in_channels,
138
+ in_channels,
139
+ kernel_size=1,
140
+ stride=1,
141
+ padding=0)
142
+ self.proj_out = torch.nn.Conv2d(in_channels,
143
+ in_channels,
144
+ kernel_size=1,
145
+ stride=1,
146
+ padding=0)
147
+
148
+ def forward(self, x):
149
+ h_ = x
150
+ h_ = self.norm(h_)
151
+ q = self.q(h_)
152
+ k = self.k(h_)
153
+ v = self.v(h_)
154
+
155
+ # compute attention
156
+ b, c, h, w = q.shape
157
+ q = q.reshape(b, c, h*w)
158
+ q = q.permute(0, 2, 1) # b,hw,c
159
+ k = k.reshape(b, c, h*w) # b,c,hw
160
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
161
+ w_ = w_ * (int(c)**(-0.5))
162
+ w_ = torch.nn.functional.softmax(w_, dim=2)
163
+
164
+ # attend to values
165
+ v = v.reshape(b, c, h*w)
166
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
167
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
168
+ h_ = h_.reshape(b, c, h, w)
169
+
170
+ h_ = self.proj_out(h_)
171
+ return x+h_
172
+
173
+
174
+ class Encoder(nn.Module):
175
+ def __init__(self,
176
+ *, # forced to use named arguments
177
+ ch: int,
178
+ out_ch: int,
179
+ ch_mult: Tuple[int] = (1, 2, 4, 8),
180
+ num_res_blocks: int,
181
+ attn_resolutions: Tuple[int],
182
+ pdrop: float = 0.0,
183
+ resamp_with_conv: bool = True,
184
+ in_channels: int,
185
+ resolution: int,
186
+ z_channels: int,
187
+ double_z: Optional[bool] = None) -> None:
188
+ super().__init__()
189
+ self.ch = ch
190
+ self.temb_ch = 0
191
+ self.num_resolutions = len(ch_mult)
192
+ self.num_res_blocks = num_res_blocks
193
+ self.resolution = resolution
194
+ self.in_channels = in_channels
195
+
196
+ # downsampling
197
+ self.conv_in = torch.nn.Conv2d(in_channels,
198
+ self.ch,
199
+ kernel_size=3,
200
+ stride=1,
201
+ padding=1)
202
+
203
+ curr_res = resolution
204
+ in_ch_mult = (1,)+tuple(ch_mult)
205
+ self.down = nn.ModuleList()
206
+ for i_level in range(self.num_resolutions):
207
+ block = nn.ModuleList()
208
+ attn = nn.ModuleList()
209
+ block_in = ch*in_ch_mult[i_level]
210
+ block_out = ch*ch_mult[i_level]
211
+ for i_block in range(self.num_res_blocks):
212
+ block.append(ResnetBlock(in_channels=block_in,
213
+ out_channels=block_out,
214
+ temb_channels=self.temb_ch,
215
+ dropout=pdrop))
216
+ block_in = block_out
217
+ if curr_res in attn_resolutions:
218
+ attn.append(AttnBlock(block_in))
219
+ down = nn.Module()
220
+ down.block = block
221
+ down.attn = attn
222
+ if i_level != self.num_resolutions-1:
223
+ down.downsample = Downsample(block_in, resamp_with_conv)
224
+ curr_res = curr_res // 2
225
+ self.down.append(down)
226
+
227
+ # middle
228
+ self.mid = nn.Module()
229
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
230
+ out_channels=block_in,
231
+ temb_channels=self.temb_ch,
232
+ dropout=pdrop)
233
+ self.mid.attn_1 = AttnBlock(block_in)
234
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
235
+ out_channels=block_in,
236
+ temb_channels=self.temb_ch,
237
+ dropout=pdrop)
238
+
239
+ # end
240
+ self.norm_out = Normalize(block_in)
241
+ self.conv_out = torch.nn.Conv2d(block_in,
242
+ 2*z_channels if double_z else z_channels,
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1)
246
+
247
+ def forward(self, x):
248
+ assert x.shape[2] == x.shape[3] == self.resolution, \
249
+ "{}, {}".format(x.shape, self.resolution)
250
+
251
+ # downsampling
252
+ h = self.conv_in(x)
253
+ for i_level in range(self.num_resolutions):
254
+ for i_block in range(self.num_res_blocks):
255
+ h = self.down[i_level].block[i_block](h)
256
+ if len(self.down[i_level].attn) > 0:
257
+ h = self.down[i_level].attn[i_block](h)
258
+ if i_level != self.num_resolutions-1:
259
+ h = self.down[i_level].downsample(h)
260
+
261
+ # middle
262
+ h = self.mid.block_1(h)
263
+ h = self.mid.attn_1(h)
264
+ h = self.mid.block_2(h)
265
+
266
+ # end
267
+ h = self.norm_out(h)
268
+ h = nonlinearity(h)
269
+ h = self.conv_out(h)
270
+ return h
271
+
272
+
273
+ class Decoder(nn.Module):
274
+ def __init__(self,
275
+ *, # forced to use named arguments
276
+ ch: int,
277
+ out_ch: int,
278
+ ch_mult: Tuple[int] = (1, 2, 4, 8),
279
+ num_res_blocks: int,
280
+ attn_resolutions: Tuple[int],
281
+ pdrop: float = 0.0,
282
+ resamp_with_conv: bool = True,
283
+ in_channels: int,
284
+ resolution: int,
285
+ z_channels: int,
286
+ double_z: bool) -> None:
287
+ super().__init__()
288
+ self.ch = ch
289
+ self.temb_ch = 0
290
+ self.num_resolutions = len(ch_mult)
291
+ self.num_res_blocks = num_res_blocks
292
+ self.resolution = resolution
293
+ self.in_channels = in_channels
294
+
295
+ # compute in_ch_mult, block_in and curr_res at lowest res
296
+ block_in = ch*ch_mult[self.num_resolutions-1]
297
+ curr_res = resolution // 2**(self.num_resolutions-1)
298
+ self.z_shape = (1, z_channels, curr_res, curr_res)
299
+
300
+ # z to block_in
301
+ self.conv_in = torch.nn.Conv2d(z_channels,
302
+ block_in,
303
+ kernel_size=3,
304
+ stride=1,
305
+ padding=1)
306
+
307
+ # middle
308
+ self.mid = nn.Module()
309
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
310
+ out_channels=block_in,
311
+ temb_channels=self.temb_ch,
312
+ dropout=pdrop)
313
+ self.mid.attn_1 = AttnBlock(block_in)
314
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
315
+ out_channels=block_in,
316
+ temb_channels=self.temb_ch,
317
+ dropout=pdrop)
318
+
319
+ # upsampling
320
+ self.up = nn.ModuleList()
321
+ for i_level in reversed(range(self.num_resolutions)):
322
+ block = nn.ModuleList()
323
+ attn = nn.ModuleList()
324
+ block_out = ch*ch_mult[i_level]
325
+ for i_block in range(self.num_res_blocks+1):
326
+ block.append(ResnetBlock(in_channels=block_in,
327
+ out_channels=block_out,
328
+ temb_channels=self.temb_ch,
329
+ dropout=pdrop))
330
+ block_in = block_out
331
+ if curr_res in attn_resolutions:
332
+ attn.append(AttnBlock(block_in))
333
+ up = nn.Module()
334
+ up.block = block
335
+ up.attn = attn
336
+ if i_level != 0:
337
+ up.upsample = Upsample(block_in, resamp_with_conv)
338
+ curr_res = curr_res * 2
339
+ self.up.insert(0, up) # prepend to get consistent order
340
+
341
+ # end
342
+ self.norm_out = Normalize(block_in)
343
+ self.conv_out = torch.nn.Conv2d(block_in,
344
+ out_ch,
345
+ kernel_size=3,
346
+ stride=1,
347
+ padding=1)
348
+
349
+ def forward(self, z):
350
+ assert z.shape[1:] == self.z_shape[1:]
351
+ self.last_z_shape = z.shape
352
+
353
+ # z to block_in
354
+ h = self.conv_in(z)
355
+
356
+ # middle
357
+ h = self.mid.block_1(h)
358
+ h = self.mid.attn_1(h)
359
+ h = self.mid.block_2(h)
360
+
361
+ # upsampling
362
+ for i_level in reversed(range(self.num_resolutions)):
363
+ for i_block in range(self.num_res_blocks+1):
364
+ h = self.up[i_level].block[i_block](h)
365
+ if len(self.up[i_level].attn) > 0:
366
+ h = self.up[i_level].attn[i_block](h)
367
+ if i_level != 0:
368
+ h = self.up[i_level].upsample(h)
369
+
370
+ h = self.norm_out(h)
371
+ h = nonlinearity(h)
372
+ h = self.conv_out(h)
373
+ return h
dalle/models/stage1/vqgan.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Modified from VQGAN (https://github.com/CompVis/taming-transformers)
3
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4
+ # ------------------------------------------------------------------------------------
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import List, Tuple, Optional
9
+ from einops import rearrange
10
+ from omegaconf import OmegaConf
11
+ from .layers import Encoder, Decoder
12
+
13
+
14
+ class VectorQuantizer(nn.Module):
15
+ """
16
+ Simplified VectorQuantizer in the original VQGAN repository
17
+ by removing unncessary modules for sampling
18
+ """
19
+ def __init__(self, dim: int, n_embed: int, beta: float) -> None:
20
+ super().__init__()
21
+ self.n_embed = n_embed
22
+ self.dim = dim
23
+ self.beta = beta
24
+
25
+ self.embedding = nn.Embedding(self.n_embed, self.dim)
26
+ self.embedding.weight.data.uniform_(-1.0 / self.n_embed, 1.0 / self.n_embed)
27
+
28
+ def forward(self,
29
+ z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
30
+ z = rearrange(z, 'b c h w -> b h w c').contiguous() # [B,C,H,W] -> [B,H,W,C]
31
+ z_flattened = z.view(-1, self.dim)
32
+
33
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
34
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
35
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
36
+
37
+ min_encoding_indices = torch.argmin(d, dim=1)
38
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
39
+ return z_q, min_encoding_indices
40
+
41
+ def get_codebook_entry(self,
42
+ indices: torch.LongTensor,
43
+ shape: Optional[List[int]] = None) -> torch.FloatTensor:
44
+ z_q = self.embedding(indices)
45
+ if shape is not None:
46
+ z_q = z_q.view(shape)
47
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
48
+ return z_q
49
+
50
+
51
+ class VQGAN(nn.Module):
52
+ def __init__(self, n_embed: int, embed_dim: int, hparams: OmegaConf) -> None:
53
+ super().__init__()
54
+ self.encoder = Encoder(**hparams)
55
+ self.decoder = Decoder(**hparams)
56
+ self.quantize = VectorQuantizer(dim=embed_dim, n_embed=n_embed, beta=0.25)
57
+ self.quant_conv = torch.nn.Conv2d(hparams.z_channels, embed_dim, 1)
58
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, hparams.z_channels, 1)
59
+ self.latent_dim = hparams.attn_resolutions[0]
60
+
61
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
62
+ quant = self.encode(x)
63
+ dec = self.decode(quant)
64
+ return dec
65
+
66
+ def encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
67
+ h = self.encoder(x)
68
+ h = self.quant_conv(h)
69
+ quant = self.quantize(h)[0]
70
+ quant = rearrange(quant, 'b h w c -> b c h w').contiguous()
71
+ return quant
72
+
73
+ def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor:
74
+ quant = self.post_quant_conv(quant)
75
+ dec = self.decoder(quant)
76
+ return dec
77
+
78
+ def decode_code(self, code: torch.LongTensor) -> torch.FloatTensor:
79
+ quant = self.quantize.get_codebook_entry(code)
80
+ quant = quant.permute(0, 3, 1, 2)
81
+ dec = self.decode(quant)
82
+ return dec
83
+
84
+ def get_codes(self, x: torch.FloatTensor) -> torch.LongTensor:
85
+ h = self.encoder(x)
86
+ h = self.quant_conv(h)
87
+ codes = self.quantize(h)[1].view(x.shape[0], self.latent_dim ** 2)
88
+ return codes
89
+
90
+ def from_ckpt(self, path: str, strict: bool = True) -> None:
91
+ ckpt = torch.load(path, map_location='cpu')['state_dict']
92
+ self.load_state_dict(ckpt, strict=strict)
93
+ print(f'{path} successfully restored..')
dalle/models/stage2/__pycache__/layers.cpython-38.pyc ADDED
Binary file (5.71 kB). View file
 
dalle/models/stage2/__pycache__/transformer.cpython-38.pyc ADDED
Binary file (11.7 kB). View file
 
dalle/models/stage2/layers.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+ # Modified from minGPT (https://github.com/karpathy/minGPT)
7
+ # Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
8
+ # ------------------------------------------------------------------------------------
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import functional as F
14
+
15
+
16
+ class GELU(nn.Module):
17
+ def __init__(self, use_approx=False):
18
+ super().__init__()
19
+ self.use_approx = use_approx
20
+
21
+ def forward(self, x):
22
+ if self.use_approx:
23
+ return x * torch.sigmoid(1.702 * x)
24
+ else:
25
+ return F.gelu(x)
26
+
27
+
28
+ class MultiHeadSelfAttention(nn.Module):
29
+
30
+ def __init__(self,
31
+ ctx_len: int,
32
+ embed_dim: int,
33
+ n_heads: int,
34
+ resid_pdrop: float,
35
+ attn_pdrop: float,
36
+ attn_bias: bool,
37
+ use_mask: bool = True):
38
+ super().__init__()
39
+ assert embed_dim % n_heads == 0
40
+
41
+ # key, query, value projections for all heads
42
+ self.key = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
43
+ self.query = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
44
+ self.value = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
45
+
46
+ # regularization
47
+ self.attn_drop = nn.Dropout(attn_pdrop)
48
+ self.resid_drop = nn.Dropout(resid_pdrop)
49
+
50
+ # output projection
51
+ self.proj = nn.Linear(embed_dim, embed_dim, attn_bias)
52
+
53
+ self.n_heads = n_heads
54
+ self.ctx_len = ctx_len
55
+ self.use_mask = use_mask
56
+ if self.use_mask:
57
+ self.register_buffer("mask", torch.ones(ctx_len, ctx_len), persistent=False)
58
+ self.mask = torch.tril(self.mask).view(1, ctx_len, ctx_len)
59
+
60
+ def forward(self, x, use_cache=False, layer_past=None):
61
+ B, T, C = x.shape
62
+ x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C)
63
+
64
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
65
+ k = self.key(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
66
+ q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
67
+ v = self.value(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
68
+
69
+ if use_cache:
70
+ present = torch.stack([k, v])
71
+
72
+ if layer_past is not None:
73
+ # print(layer_past.shape, k.shape, v.shape, q.shape)
74
+ # print("LayerPast shape", layer_past.shape)
75
+ past_key, past_value = layer_past
76
+
77
+ if len(past_key.shape) == 4:
78
+ _, _, seq_len, dim = past_key.shape
79
+ k = torch.cat([past_key.reshape(-1, seq_len, dim), k], dim=-2)
80
+ v = torch.cat([past_value.reshape(-1, seq_len, dim), v], dim=-2)
81
+ elif len(past_key.shape) == 3:
82
+ past_key, past_value = layer_past
83
+ k = torch.cat([past_key, k], dim=-2)
84
+ v = torch.cat([past_value, v], dim=-2)
85
+ else:
86
+ raise ValueError
87
+
88
+ if use_cache and layer_past is not None:
89
+ # Tensor shape below: (B * nh, 1, hs) X (B * nh, hs, K) -> (B * nh, 1, K)
90
+ att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
91
+ att = F.softmax(att, dim=-1)
92
+ att = self.attn_drop(att)
93
+ y = torch.bmm(att, v) # (B*nh, 1, K) X (B*nh, K, hs) -> (B*nh, 1, hs)
94
+ else:
95
+ # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, T) -> (B * nh, T, T)
96
+ att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
97
+ if self.use_mask:
98
+ # TODO : Flip when not prompt tunign
99
+ # mask = self.mask if T == self.ctx_len else self.mask[:, :T, :T]
100
+ if T == self.ctx_len:
101
+ mask = self.mask
102
+ else:
103
+ mask = torch.tril(torch.ones(T, T)).view(1, T, T).to(att.device)
104
+ att = att.masked_fill(mask == 0, float('-inf'))
105
+ att = F.softmax(att, dim=-1)
106
+ att = self.attn_drop(att)
107
+ y = torch.bmm(att, v) # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs)
108
+ y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side
109
+
110
+ # output projection
111
+ y = self.resid_drop(self.proj(y))
112
+ if use_cache:
113
+ return y.transpose(0, 1).contiguous(), present # (T, B, C) -> (B, T, C)
114
+ else:
115
+ return y.transpose(0, 1).contiguous() # (T, B, C) -> (B, T, C)
116
+
117
+ def forward_with_context(self, x, context, mask=None):
118
+ B, T, C = x.shape
119
+ x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C)
120
+
121
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
122
+ q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
123
+
124
+ B, T_c, C = context.shape
125
+ k = self.key(context).view(T_c, B * self.n_heads, C // self.n_heads).transpose(0, 1) # (B*nh, T, hs)
126
+ v = self.value(context).view(T_c, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
127
+
128
+ # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, Tc) -> (B * nh, T, Tc)
129
+ att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
130
+ att = F.softmax(att, dim=-1)
131
+ att = self.attn_drop(att)
132
+ y = torch.bmm(att, v) # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs)
133
+ y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side
134
+
135
+ # output projection
136
+ y = self.resid_drop(self.proj(y)).transpose(0, 1).contiguous()
137
+ if mask is not None:
138
+ y = y.masked_fill(mask == 0, float('0.0'))
139
+ return y # (T, B, C) -> (B, T, C)
140
+
141
+
142
+ class Block(nn.Module):
143
+
144
+ def __init__(self,
145
+ ctx_len: int,
146
+ embed_dim: int,
147
+ n_heads: int,
148
+ mlp_bias: bool,
149
+ attn_bias: bool,
150
+ resid_pdrop: bool,
151
+ attn_pdrop: bool,
152
+ gelu_use_approx: bool):
153
+ super().__init__()
154
+ self.ln1 = nn.LayerNorm(embed_dim)
155
+ self.ln2 = nn.LayerNorm(embed_dim)
156
+
157
+ self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
158
+ embed_dim=embed_dim,
159
+ n_heads=n_heads,
160
+ attn_pdrop=attn_pdrop,
161
+ resid_pdrop=resid_pdrop,
162
+ attn_bias=attn_bias,
163
+ use_mask=True)
164
+ self.mlp = nn.Sequential(
165
+ nn.Linear(embed_dim, 4 * embed_dim, bias=mlp_bias),
166
+ GELU(gelu_use_approx),
167
+ nn.Linear(4 * embed_dim, embed_dim, bias=mlp_bias),
168
+ nn.Dropout(resid_pdrop),
169
+ )
170
+
171
+ def forward(self, x, layer_past=None):
172
+ x = x + self.attn(self.ln1(x), layer_past=layer_past)
173
+ x = x + self.mlp(self.ln2(x))
174
+ return x
175
+
176
+ def sample(self, x, layer_past=None):
177
+ attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
178
+ x = x + attn
179
+ x = x + self.mlp(self.ln2(x))
180
+ return x, present
181
+
182
+ def sample_with_context(self, x, context, context_mask, cross_attn_layer, layer_past=None):
183
+ attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
184
+ x = x + attn
185
+ c_attn = cross_attn_layer(x, context, context_mask)
186
+ x = x + c_attn
187
+ x = x + self.mlp(self.ln2(x))
188
+ return x, present
189
+
190
+
191
+ class CrossAttentionLayer(nn.Module):
192
+
193
+ def __init__(self,
194
+ ctx_len: int,
195
+ embed_dim: int,
196
+ n_heads: int,
197
+ attn_bias: bool,
198
+ resid_pdrop: bool,
199
+ attn_pdrop: bool):
200
+ super().__init__()
201
+
202
+ self.ln1 = nn.LayerNorm(embed_dim)
203
+ self.ln2 = nn.LayerNorm(embed_dim)
204
+ self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
205
+ embed_dim=embed_dim,
206
+ n_heads=n_heads,
207
+ attn_pdrop=attn_pdrop,
208
+ resid_pdrop=resid_pdrop,
209
+ attn_bias=attn_bias,
210
+ use_mask=False)
211
+
212
+ def forward(self, x, context, context_mask=None):
213
+ attn = self.attn.forward_with_context(self.ln1(x), self.ln2(context), context_mask)
214
+ # x = x + attn
215
+ # return x
216
+ return attn
dalle/models/stage2/transformer.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+ # Modified from minGPT (https://github.com/karpathy/minGPT)
7
+ # Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
8
+ # ------------------------------------------------------------------------------------
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from typing import Optional, Tuple, List
13
+ from torch.cuda.amp import autocast
14
+ from omegaconf import OmegaConf
15
+ from .layers import Block
16
+
17
+ class Transformer1d(nn.Module):
18
+
19
+ def __init__(self,
20
+ vocab_size_txt: int,
21
+ vocab_size_img: int,
22
+ hparams: OmegaConf) -> None:
23
+ super().__init__()
24
+ assert hparams.n_layers == hparams.n_dense_layers
25
+
26
+ # input embedding for image and text
27
+ self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
28
+ self.tok_emb_txt = nn.Embedding(vocab_size_txt, hparams.embed_dim)
29
+
30
+ self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
31
+ self.pos_emb_txt = nn.Embedding(hparams.ctx_len_txt, hparams.embed_dim)
32
+
33
+ self.drop = nn.Dropout(hparams.embd_pdrop)
34
+
35
+ # transformer blocks
36
+ self.blocks = [Block(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
37
+ embed_dim=hparams.embed_dim,
38
+ n_heads=hparams.n_heads,
39
+ mlp_bias=hparams.mlp_bias,
40
+ attn_bias=hparams.attn_bias,
41
+ resid_pdrop=hparams.resid_pdrop,
42
+ attn_pdrop=hparams.attn_pdrop,
43
+ gelu_use_approx=hparams.gelu_use_approx) for i in range(1, hparams.n_layers+1)]
44
+ self.blocks = nn.Sequential(*self.blocks)
45
+
46
+ # heads for image and text
47
+ self.ln_f = nn.LayerNorm(hparams.embed_dim)
48
+ self.head_img = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
49
+ self.head_txt = nn.Linear(hparams.embed_dim, vocab_size_txt, bias=False)
50
+
51
+ self.ctx_len_img = hparams.ctx_len_img
52
+ self.ctx_len_txt = hparams.ctx_len_txt
53
+ self.n_layers = hparams.n_layers
54
+
55
+ self.apply(self._init_weights)
56
+
57
+
58
+ def _init_weights(self, module: nn.Module) -> None:
59
+ if isinstance(module, (nn.Linear, nn.Embedding)):
60
+ module.weight.data.normal_(mean=0.0, std=0.02)
61
+ if isinstance(module, nn.Linear) and module.bias is not None:
62
+ module.bias.data.zero_()
63
+ elif isinstance(module, nn.LayerNorm):
64
+ module.bias.data.zero_()
65
+ module.weight.data.fill_(1.0)
66
+
67
+
68
+ def resize_token_embeddings(self, new_num_tokens):
69
+
70
+ old_num_tokens, old_embedding_dim = self.tok_emb_txt.weight.size()
71
+ new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
72
+ new_embeddings.to(self.tok_emb_txt.weight.device, dtype=self.tok_emb_txt.weight.dtype)
73
+ self._init_weights(new_embeddings)
74
+ # numbers of tokens to copy
75
+ n = min(old_num_tokens, new_num_tokens)
76
+ new_embeddings.weight.data[:n, :] = self.tok_emb_txt.weight.data[:n, :]
77
+ self.tok_emb_txt = new_embeddings
78
+
79
+ self.resize_lm_head(new_num_tokens)
80
+ # TODO: also change config to reflect new vocab size
81
+
82
+ return new_embeddings
83
+
84
+
85
+ def resize_lm_head(
86
+ self, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False) -> nn.Linear:
87
+
88
+ old_num_tokens, old_lm_head_dim = (
89
+ self.head_txt.weight.size() if not transposed else self.head_txt.weight.t().size()
90
+ )
91
+ # Build new lm head
92
+ new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
93
+ has_new_lm_head_bias = self.head_txt.bias is not None
94
+ new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
95
+ new_lm_head = new_lm_head.to(self.head_txt.weight.device, dtype=self.head_txt.weight.dtype)
96
+
97
+ # initialize new lm head (in particular added tokens)
98
+ self._init_weights(new_lm_head)
99
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
100
+ # Copy old lm head weights to new lm head
101
+ if not transposed:
102
+ new_lm_head.weight.data[:num_tokens_to_copy, :] = self.head_txt.weight.data[:num_tokens_to_copy, :]
103
+ else:
104
+ new_lm_head.weight.data[:, :num_tokens_to_copy] = self.head_txt.weight.data[:, :num_tokens_to_copy]
105
+
106
+ # Copy bias weights to new lm head
107
+ if has_new_lm_head_bias:
108
+ new_lm_head.bias.data[:num_tokens_to_copy] = self.head_txt.bias.data[:num_tokens_to_copy]
109
+
110
+ self.head_txt = new_lm_head
111
+
112
+ return new_lm_head
113
+
114
+
115
+ def forward(self,
116
+ images: torch.LongTensor,
117
+ texts: torch.LongTensor,
118
+ pos_images: torch.LongTensor,
119
+ pos_texts: torch.LongTensor,
120
+ past: Optional[List[torch.Tensor]] = None,
121
+ prompt: Optional[List[torch.Tensor]] = None,
122
+ pos_prompt: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
123
+
124
+
125
+ B, T = images.shape
126
+ _, N = texts.shape
127
+
128
+ assert T <= self.ctx_len_img, "Already reached the maximum context length (image)."
129
+ assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
130
+
131
+ texts = self.tok_emb_txt(texts)
132
+ images = self.tok_emb_img(images)
133
+
134
+ texts = texts + self.pos_emb_txt(pos_texts)
135
+ images = images + self.pos_emb_img(pos_images)
136
+
137
+ if prompt is not None:
138
+ prompt = prompt + self.pos_emb_txt(pos_prompt)
139
+ texts = torch.cat([prompt, texts], dim=1).contiguous()
140
+ P = prompt.shape[1]
141
+
142
+ x = torch.cat([texts, images], dim=1).contiguous()
143
+ x = self.drop(x)
144
+
145
+ # x = self.blocks(x)
146
+ for i, block in enumerate(self.blocks):
147
+ x, _ = block.sample(x, layer_past=None if past is None else past[i])
148
+
149
+ x = self.ln_f(x)
150
+
151
+ if prompt is not None:
152
+ texts = x[:, P:N+P-1].contiguous()
153
+ images = x[:, N+P-1:-1].contiguous()
154
+ else:
155
+ texts = x[:, :N-1].contiguous()
156
+ images = x[:, N-1:-1].contiguous()
157
+
158
+ logits_txt = self.head_txt(texts)
159
+ logits_img = self.head_img(images)
160
+ return logits_img, logits_txt
161
+
162
+ def forward_with_context(self,
163
+ images: torch.LongTensor,
164
+ texts: torch.LongTensor,
165
+ pos_images: torch.LongTensor,
166
+ pos_texts: torch.LongTensor,
167
+ src_images: torch.LongTensor,
168
+ src_pos_images: torch.LongTensor,
169
+ cross_attention_idxs: List,
170
+ cross_attention_layers,
171
+ past: Optional[List[torch.Tensor]] = None,
172
+ prompt: Optional[List[torch.Tensor]] = None,
173
+ pos_prompt: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
174
+
175
+
176
+ B, T = images.shape
177
+ _, N = texts.shape
178
+
179
+ assert T <= self.ctx_len_img, "Already reached the maximum context length (image)."
180
+ assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
181
+
182
+ texts = self.tok_emb_txt(texts)
183
+ images = self.tok_emb_img(images)
184
+ src_images = self.tok_emb_img(src_images)
185
+
186
+ texts = texts + self.pos_emb_txt(pos_texts)
187
+ images = images + self.pos_emb_img(pos_images)
188
+ src_images = src_images + self.pos_emb_img(src_pos_images)
189
+
190
+ if prompt is not None:
191
+ prompt = prompt + self.pos_emb_txt(pos_prompt)
192
+ texts = torch.cat([prompt, texts], dim=1).contiguous()
193
+ P = prompt.shape[1]
194
+ else:
195
+ P = 0
196
+
197
+ x = torch.cat([texts, images], axis=1).contiguous()
198
+ x = self.drop(x)
199
+
200
+ # prepare mask
201
+ mask = torch.zeros_like(x[0])
202
+ mask[self.ctx_len_txt+P-1:, :].fill_(1.0)
203
+ mask = mask.unsqueeze(0)
204
+
205
+ # print(images.shape, texts.shape, src_images.shape, mask.shape, x.shape)
206
+
207
+ # x = self.blocks(x)
208
+ for i, block in enumerate(self.blocks):
209
+ if i in cross_attention_idxs:
210
+ x, _ = block.sample_with_context(x, src_images, mask, cross_attention_layers[int(((i+1)/3)-1)], layer_past=None if past is None else past[i])
211
+ else:
212
+ x, _ = block.sample(x, layer_past=None if past is None else past[i])
213
+
214
+ x = self.ln_f(x)
215
+
216
+ if prompt is not None:
217
+ texts = x[:, P:N+P-1].contiguous()
218
+ images = x[:, N+P-1:-1].contiguous()
219
+ else:
220
+ texts = x[:, :N-1].contiguous()
221
+ images = x[:, N-1:-1].contiguous()
222
+
223
+ logits_txt = self.head_txt(texts)
224
+ logits_img = self.head_img(images)
225
+ return logits_img, logits_txt
226
+
227
+ @torch.no_grad()
228
+ def sampling(self,
229
+ images: torch.LongTensor,
230
+ texts: torch.LongTensor,
231
+ pos_images: torch.LongTensor,
232
+ pos_texts: torch.LongTensor,
233
+ use_fp16: bool = True,
234
+ past: Optional[List[torch.Tensor]] = None,
235
+ prompt: Optional[List[torch.Tensor]] = None,
236
+ pos_prompt: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
237
+
238
+ _, N = texts.shape
239
+ assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
240
+
241
+ with autocast(enabled=use_fp16):
242
+ if images is None:
243
+ # assert past is None
244
+
245
+ texts = self.tok_emb_txt(texts)
246
+ x = texts + self.pos_emb_txt(pos_texts)
247
+
248
+ if prompt is not None:
249
+ prompt = prompt + self.pos_emb_txt(pos_prompt)
250
+ texts = torch.cat([prompt, texts], dim=1).contiguous()
251
+
252
+ x = self.drop(x)
253
+
254
+ if past is not None:
255
+ past = torch.cat(past, dim=-2)
256
+
257
+ presents = []
258
+ for i, block in enumerate(self.blocks):
259
+ x, present = block.sample(x, layer_past=None if past is None else past[i])
260
+ presents.append(present)
261
+ x = self.ln_f(x)
262
+ x = x[:, N-1].contiguous()
263
+ logits = self.head_img(x)
264
+ else:
265
+ if past is None:
266
+ texts = self.tok_emb_txt(texts)
267
+ images = self.tok_emb_img(images)
268
+ texts = texts + self.pos_emb_txt(pos_texts)
269
+ images = images + self.pos_emb_img(pos_images)
270
+
271
+ if prompt is not None:
272
+ prompt = prompt + self.pos_emb_txt(pos_prompt)
273
+ texts = torch.cat([prompt, texts], dim=1).contiguous()
274
+
275
+ x = torch.cat([texts, images], axis=1).contiguous()
276
+ else:
277
+ images = self.tok_emb_img(images)
278
+ x = images + self.pos_emb_img(pos_images)
279
+ x = self.drop(x)
280
+
281
+ # if past is not None and len(past) > 1:
282
+ if past is not None:
283
+ past = torch.cat(past, dim=-2)
284
+ # print('Past', past.shape)
285
+ presents = []
286
+ # print(len(past), past[0].shape)
287
+ for i, block in enumerate(self.blocks):
288
+ x, present = block.sample(x, layer_past=None if past is None else past[i])
289
+ presents.append(present)
290
+ x = self.ln_f(x)
291
+ x = x[:, -1].contiguous()
292
+ logits = self.head_img(x)
293
+ return logits, presents
294
+
295
+ @torch.no_grad()
296
+ def sampling_with_context(self,
297
+ images: torch.LongTensor,
298
+ cross_attention_idxs,
299
+ cross_attention_layers,
300
+ texts: torch.LongTensor,
301
+ pos_images: torch.LongTensor,
302
+ pos_texts: torch.LongTensor,
303
+ source_image: torch.LongTensor,
304
+ use_fp16: bool = True,
305
+ past: Optional[List[torch.Tensor]] = None,
306
+ prompt: Optional[List[torch.Tensor]] = None,
307
+ pos_prompt: Optional[List[torch.Tensor]] = None
308
+ ) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
309
+
310
+ _, N = texts.shape
311
+ assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
312
+
313
+ if prompt is not None:
314
+ P = prompt.shape[1]
315
+ else:
316
+ P = 0
317
+
318
+ with autocast(enabled=use_fp16):
319
+ if images is None:
320
+ # assert past is None
321
+
322
+ texts = self.tok_emb_txt(texts)
323
+ texts = texts + self.pos_emb_txt(pos_texts)
324
+
325
+ if prompt is not None:
326
+ prompt = prompt + self.pos_emb_txt(pos_prompt)
327
+ texts = torch.cat([prompt, texts], dim=1).contiguous()
328
+
329
+ x = self.drop(texts)
330
+
331
+ if past is not None:
332
+ past = torch.cat(past, dim=-2)
333
+
334
+ # prepare mask
335
+ mask = torch.zeros_like(x[0])
336
+ mask[self.ctx_len_txt+P - 1:, :].fill_(1.0)
337
+ mask = mask.unsqueeze(0)
338
+
339
+ presents = []
340
+ for i, block in enumerate(self.blocks):
341
+ if i in cross_attention_idxs:
342
+ x, present = block.sample_with_context(x, source_image, mask,
343
+ cross_attention_layers[int(((i + 1) / 3) - 1)],
344
+ layer_past=None if past is None else past[i])
345
+ else:
346
+ x, present = block.sample(x, layer_past=None if past is None else past[i])
347
+ presents.append(present)
348
+ x = self.ln_f(x)
349
+ x = x[:, N-1].contiguous()
350
+ logits = self.head_img(x)
351
+ else:
352
+ if past is None:
353
+ texts = self.tok_emb_txt(texts)
354
+ images = self.tok_emb_img(images)
355
+ texts = texts + self.pos_emb_txt(pos_texts)
356
+ images = images + self.pos_emb_img(pos_images)
357
+
358
+ if prompt is not None:
359
+ prompt = prompt + self.pos_emb_txt(pos_prompt)
360
+ texts = torch.cat([prompt, texts], dim=1).contiguous()
361
+
362
+ x = torch.cat([texts, images], axis=1).contiguous()
363
+ else:
364
+ images = self.tok_emb_img(images)
365
+ x = images + self.pos_emb_img(pos_images)
366
+ x = self.drop(x)
367
+
368
+ # if past is not None and len(past) > 1:
369
+ if past is not None:
370
+ past = torch.cat(past, dim=-2)
371
+ presents = []
372
+
373
+ # prepare mask
374
+ mask = torch.zeros_like(x[0])
375
+ mask[self.ctx_len_txt+P - 1:, :].fill_(1.0)
376
+ mask = mask.unsqueeze(0)
377
+
378
+ # print(len(past), past[0].shape)
379
+ for i, block in enumerate(self.blocks):
380
+ if i in cross_attention_idxs:
381
+ x, present = block.sample_with_context(x, source_image, mask,
382
+ cross_attention_layers[int(((i + 1) / 3) - 1)],
383
+ layer_past=None if past is None else past[i])
384
+ else:
385
+ x, present = block.sample(x, layer_past=None if past is None else past[i])
386
+ presents.append(present)
387
+ x = self.ln_f(x)
388
+ x = x[:, -1].contiguous()
389
+ logits = self.head_img(x)
390
+ return logits, presents
391
+
392
+ def from_ckpt(self, path: str) -> None:
393
+ ckpt = torch.load(path, map_location='cpu')['state_dict']
394
+ self.load_state_dict(ckpt, strict=True)
395
+ print(f'{path} succesfully restored..')
396
+
397
+
398
+ class iGPT(nn.Module):
399
+ def __init__(self,
400
+ vocab_size_img: int,
401
+ use_cls_cond: bool,
402
+ hparams: OmegaConf) -> None:
403
+ super().__init__()
404
+ self.use_cls_cond = use_cls_cond
405
+
406
+ # sos token embedding
407
+ if self.use_cls_cond:
408
+ self.sos = nn.Embedding(hparams.n_classes, hparams.embed_dim)
409
+ else:
410
+ self.sos = nn.Parameter(torch.randn(1, 1, hparams.embed_dim))
411
+
412
+ # input embedding
413
+ self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
414
+ self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
415
+
416
+ self.drop = nn.Dropout(hparams.embd_pdrop)
417
+
418
+ # transformer blocks
419
+ self.blocks = [Block(ctx_len=hparams.ctx_len_img + 1,
420
+ embed_dim=hparams.embed_dim,
421
+ n_heads=hparams.n_heads,
422
+ mlp_bias=hparams.mlp_bias,
423
+ attn_bias=hparams.attn_bias,
424
+ resid_pdrop=hparams.resid_pdrop,
425
+ attn_pdrop=hparams.attn_pdrop,
426
+ gelu_use_approx=hparams.gelu_use_approx) for i in range(1, hparams.n_layers+1)]
427
+ self.blocks = nn.Sequential(*self.blocks)
428
+
429
+ # head
430
+ self.ln_f = nn.LayerNorm(hparams.embed_dim)
431
+ self.head = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
432
+
433
+ self.ctx_len_img = hparams.ctx_len_img
434
+ self.n_layers = hparams.n_layers
435
+
436
+ self.apply(self._init_weights)
437
+
438
+ def _init_weights(self, module: nn.Module) -> None:
439
+ if isinstance(module, (nn.Linear, nn.Embedding)):
440
+ module.weight.data.normal_(mean=0.0, std=0.02)
441
+ if isinstance(module, nn.Linear) and module.bias is not None:
442
+ module.bias.data.zero_()
443
+ elif isinstance(module, nn.LayerNorm):
444
+ module.bias.data.zero_()
445
+ module.weight.data.fill_(1.0)
446
+
447
+ @torch.no_grad()
448
+ def sampling(self,
449
+ sos: torch.FloatTensor,
450
+ codes: torch.LongTensor,
451
+ pos_codes: torch.LongTensor,
452
+ n_samples: int = 16,
453
+ use_fp16: bool = True,
454
+ past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
455
+ with autocast(enabled=use_fp16):
456
+ if codes is None:
457
+ assert past is None
458
+ xs = self.drop(sos)
459
+ presents = []
460
+ for i, block in enumerate(self.blocks):
461
+ xs, present = block.sample(xs, layer_past=None)
462
+ presents.append(present)
463
+ xs = self.ln_f(xs)
464
+ logits = self.head(xs)[:, -1]
465
+ else:
466
+ if past is None:
467
+ xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
468
+ xs = torch.cat([sos, xs], dim=1)
469
+ else:
470
+ xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
471
+ xs = self.drop(xs)
472
+
473
+ past = torch.cat(past, dim=-2) if past is not None else past
474
+ presents = []
475
+ for i, block in enumerate(self.blocks):
476
+ xs, present = block.sample(xs, layer_past=None if past is None else past[i])
477
+ presents.append(present)
478
+
479
+ xs = self.ln_f(xs)
480
+ logits = self.head(xs)[:, -1]
481
+ return logits, presents
482
+
483
+ def forward(self,
484
+ codes: torch.LongTensor,
485
+ labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
486
+ B, T = codes.shape
487
+ xps = torch.arange(T, device=codes.device).repeat((B, 1))
488
+ sos = self.sos.repeat((B, 1, 1)) if labels is None else self.sos(labels).unsqueeze(1)
489
+
490
+ h = self.tok_emb_img(codes) + self.pos_emb_img(xps)
491
+ h = torch.cat([sos, h[:, :-1]], dim=1).contiguous()
492
+
493
+ h = self.drop(h)
494
+ h = self.blocks(h)
495
+ h = self.ln_f(h)
496
+ logits = self.head(h)
497
+ return logits
498
+
499
+ def from_ckpt(self, path: str, strict: bool = True) -> None:
500
+ ckpt = torch.load(path, map_location='cpu')['state_dict']
501
+ self.load_state_dict(ckpt, strict=strict)
502
+ print(f'{path} successfully restored..')
dalle/models/tokenizer.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ from functools import partial
9
+ from tokenizers import CharBPETokenizer
10
+
11
+
12
+ def build_tokenizer(path: str,
13
+ context_length: int = 64,
14
+ *args,
15
+ **kwargs):
16
+ try:
17
+ from_file = partial(CharBPETokenizer.from_file,
18
+ vocab_filename=os.path.join(path, 'bpe-16k-vocab.json'),
19
+ merges_filename=os.path.join(path, 'bpe-16k-merges.txt'),
20
+ unk_token='[UNK]')
21
+ tokenizer = from_file(*args, **kwargs)
22
+ except:
23
+ from_file = partial(CharBPETokenizer.from_file,
24
+ vocab_filename=os.path.join(path, 'vocab.json'),
25
+ merges_filename=os.path.join(path, 'merges.txt'),
26
+ unk_token='[UNK]')
27
+ tokenizer = from_file(*args, **kwargs)
28
+
29
+ # tokenizer = from_file(*args, **kwargs)
30
+ tokenizer.add_special_tokens(['[PAD]'])
31
+ tokenizer.enable_padding(length=context_length,
32
+ pad_id=tokenizer.token_to_id('[PAD]'))
33
+ tokenizer.enable_truncation(max_length=context_length)
34
+ print(f'{path} successfully restored..')
35
+ return tokenizer
dalle/trainer_prefix.py ADDED
@@ -0,0 +1,1629 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import json
3
+ import math
4
+ import os
5
+ import re
6
+ import shutil
7
+ import warnings
8
+ from contextlib import contextmanager
9
+ from pathlib import Path
10
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
11
+
12
+ from nltk import word_tokenize
13
+ import numpy as np
14
+ import torch
15
+ from packaging import version
16
+ from torch import nn
17
+ from torch.utils.data.dataloader import DataLoader
18
+ from torch.utils.data.dataset import Dataset
19
+ from torch.utils.data.distributed import DistributedSampler
20
+ from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
21
+ from tqdm.auto import tqdm, trange
22
+ from torch.nn.utils.rnn import pad_sequence
23
+ import random
24
+
25
+ from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
26
+ from transformers.file_utils import is_datasets_available, is_torch_tpu_available
27
+ from transformers.integrations import (
28
+ default_hp_search_backend,
29
+ is_comet_available,
30
+ is_optuna_available,
31
+ is_ray_available,
32
+ is_tensorboard_available,
33
+ is_wandb_available,
34
+ run_hp_search_optuna,
35
+ run_hp_search_ray,
36
+ )
37
+
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.optimization import AdamW, get_linear_schedule_with_warmup, get_constant_schedule_with_warmup
40
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
41
+ from transformers.trainer_utils import (
42
+ PREFIX_CHECKPOINT_DIR,
43
+ BestRun,
44
+ EvalPrediction,
45
+ EvaluationStrategy,
46
+ HPSearchBackend,
47
+ PredictionOutput,
48
+ TrainOutput,
49
+ default_compute_objective,
50
+ default_hp_space,
51
+ set_seed,
52
+ )
53
+ from transformers.training_args import TrainingArguments
54
+ from transformers.utils import logging
55
+
56
+
57
+ _use_native_amp = False
58
+ _use_apex = False
59
+ EPS = 1e-12
60
+ INIT_GUMBEL_TEMP = 5.0
61
+
62
+ control_lst = ['positive', 'negative', 'neutral']
63
+ Control_Temp = {'positive': 3967, 'negative':4633, 'neutral':8500}
64
+ control_Map = [torch.LongTensor([3967]), torch.LongTensor([4633]), torch.LongTensor([8500])]
65
+ sst_lst = [(0, 2), (1, 3), (4,)]
66
+ sst_standard = ["positive", "negative", "very positive", "very negative", "neutral"]
67
+ # Control_?Map = {j:i for i, j in enumerate(control_lst)}
68
+
69
+ # Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
70
+ if version.parse(torch.__version__) < version.parse("1.6"):
71
+ from transformers.file_utils import is_apex_available
72
+
73
+ if is_apex_available():
74
+ from apex import amp
75
+ _use_apex = True
76
+ else:
77
+ _use_native_amp = True
78
+ from torch.cuda.amp import autocast
79
+
80
+ if is_datasets_available():
81
+ import datasets
82
+
83
+ if is_torch_tpu_available():
84
+ import torch_xla.core.xla_model as xm
85
+ import torch_xla.debug.metrics as met
86
+ import torch_xla.distributed.parallel_loader as pl
87
+
88
+ if is_tensorboard_available():
89
+ try:
90
+ from torch.utils.tensorboard import SummaryWriter
91
+ except ImportError:
92
+ from tensorboardX import SummaryWriter
93
+
94
+ if is_wandb_available():
95
+ import wandb
96
+
97
+ if is_comet_available():
98
+ import comet_ml
99
+
100
+ if is_optuna_available():
101
+ import optuna
102
+
103
+ if is_ray_available():
104
+ from ray import tune
105
+
106
+
107
+ logger = logging.get_logger(__name__)
108
+
109
+
110
+ @contextmanager
111
+ def torch_distributed_zero_first(local_rank: int):
112
+ """
113
+ Decorator to make all processes in distributed training wait for each local_master to do something.
114
+
115
+ Args:
116
+ local_rank (:obj:`int`): The rank of the local process.
117
+ """
118
+ if local_rank not in [-1, 0]:
119
+ torch.distributed.barrier()
120
+ yield
121
+ if local_rank == 0:
122
+ torch.distributed.barrier()
123
+
124
+ def helper_token2bpe(offsets):
125
+ full_lst = []
126
+ for example_offset in offsets:
127
+ bpe2token = []
128
+ token2bpe = []
129
+ token_idx = -1
130
+ # print(example_offset)
131
+ for bpe_idx, (a,b) in enumerate(example_offset):
132
+ # print(token2bpe, a, b, bpe_idx)
133
+ if b - a > 0:
134
+ if a == 0:
135
+ # new token
136
+ token_idx += 1
137
+ bpe2token.append(token_idx)
138
+ token2bpe.append([])
139
+ token2bpe[-1].append(bpe_idx)
140
+ else:
141
+ # prev token.
142
+ bpe2token.append(token_idx)
143
+ token2bpe[-1].append(bpe_idx)
144
+ else:
145
+ bpe2token.append(None)
146
+ full_lst.append((bpe2token, token2bpe))
147
+ return full_lst
148
+
149
+ class SequentialDistributedSampler(Sampler):
150
+ """
151
+ Distributed Sampler that subsamples indicies sequentially,
152
+ making it easier to collate all results at the end.
153
+
154
+ Even though we only use this sampler for eval and predict (no training),
155
+ which means that the model params won't have to be synced (i.e. will not hang
156
+ for synchronization even if varied number of forward passes), we still add extra
157
+ samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
158
+ to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
159
+ """
160
+
161
+ def __init__(self, dataset, num_replicas=None, rank=None):
162
+ if num_replicas is None:
163
+ if not torch.distributed.is_available():
164
+ raise RuntimeError("Requires distributed package to be available")
165
+ num_replicas = torch.distributed.get_world_size()
166
+ if rank is None:
167
+ if not torch.distributed.is_available():
168
+ raise RuntimeError("Requires distributed package to be available")
169
+ rank = torch.distributed.get_rank()
170
+ self.dataset = dataset
171
+ self.num_replicas = num_replicas
172
+ self.rank = rank
173
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
174
+ self.total_size = self.num_samples * self.num_replicas
175
+
176
+ def __iter__(self):
177
+ indices = list(range(len(self.dataset)))
178
+
179
+ # add extra samples to make it evenly divisible
180
+ indices += indices[: (self.total_size - len(indices))]
181
+ assert (
182
+ len(indices) == self.total_size
183
+ ), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
184
+
185
+ # subsample
186
+ indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
187
+ assert (
188
+ len(indices) == self.num_samples
189
+ ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
190
+
191
+ return iter(indices)
192
+
193
+ def __len__(self):
194
+ return self.num_samples
195
+
196
+
197
+ def get_tpu_sampler(dataset: Dataset):
198
+ if xm.xrt_world_size() <= 1:
199
+ return RandomSampler(dataset)
200
+ return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
201
+
202
+
203
+ class Trainer_Prefix:
204
+ """
205
+ Trainer is a simple but feature-complete training and eval loop for PyTorch,
206
+ optimized for 🤗 Transformers.
207
+
208
+ Args:
209
+ model (:class:`~transformers.PreTrainedModel`, `optional`):
210
+ The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
211
+ args (:class:`~transformers.TrainingArguments`, `optional`):
212
+ The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
213
+ with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
214
+ data_collator (:obj:`DataCollator`, `optional`):
215
+ The function to use to form a batch from a list of elements of :obj:`train_dataset` or
216
+ :obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
217
+ provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
218
+ train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
219
+ The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
220
+ ``model.forward()`` method are automatically removed.
221
+ eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
222
+ The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
223
+ ``model.forward()`` method are automatically removed.
224
+ tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
225
+ The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the
226
+ maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
227
+ interrupted training or reuse the fine-tuned model.
228
+ model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
229
+ A function that instantiates the model to be used. If provided, each call to
230
+ :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
231
+ compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
232
+ The function that will be used to compute metrics at evaluation. Must take a
233
+ :class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
234
+ tb_writer (:obj:`SummaryWriter`, `optional`):
235
+ Object to write to TensorBoard.
236
+ optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`):
237
+ A tuple containing the optimizer and the scheduler to use. Will default to an instance of
238
+ :class:`~transformers.AdamW` on your model and a scheduler given by
239
+ :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
240
+ kwargs:
241
+ Deprecated keyword arguments.
242
+ """
243
+
244
+ def __init__(
245
+ self,
246
+ model: Optional[PreTrainedModel] = None,
247
+ model_gpt2 : Optional[PreTrainedModel] = None,
248
+ args: TrainingArguments = None,
249
+ data_collator: Optional[DataCollator] = None,
250
+ train_dataset: Optional[Dataset] = None,
251
+ eval_dataset: Optional[Dataset] = None,
252
+ tokenizer: Optional["PreTrainedTokenizerBase"] = None,
253
+ model_init: Callable[[], PreTrainedModel] = None,
254
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
255
+ tb_writer: Optional["SummaryWriter"] = None,
256
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
257
+ task_mode: Optional[str] = None,
258
+ use_dropout: Optional[bool] = False,
259
+ distill: Optional[bool] = False,
260
+ matching_objective:Optional[str]= None,
261
+ finetuned_gpt2: Optional[PreTrainedModel] = None,
262
+ **kwargs,
263
+ ):
264
+ if args is None:
265
+ logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
266
+ args = TrainingArguments("tmp_trainer")
267
+ self.args = args
268
+ # Seed must be set before instantiating the model when using model
269
+ set_seed(self.args.seed)
270
+ assert (
271
+ model is not None or model_init is not None
272
+ ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
273
+ assert model_init is None
274
+ self.model = model.to(args.device) if model is not None else None
275
+ self.gpt2 = model_gpt2.to(args.device) if model_gpt2 is not None else None
276
+ default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
277
+ self.data_collator = data_collator if data_collator is not None else default_collator
278
+ self.train_dataset = train_dataset
279
+ self.eval_dataset = eval_dataset
280
+ self.tokenizer = tokenizer
281
+ self.model_init = model_init
282
+ self.compute_metrics = compute_metrics
283
+ self.optimizer, self.lr_scheduler = optimizers
284
+ self.task_mode = task_mode
285
+ self.use_dropout = use_dropout
286
+
287
+ self.curr_best_eval = 10000000.
288
+
289
+ self.distill = distill
290
+ if self.distill:
291
+ self.matching_objective = matching_objective
292
+ self.finetuned_gpt2 = finetuned_gpt2
293
+
294
+ if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
295
+ raise RuntimeError(
296
+ "Passing a `model_init` is incompatible with providing the `optimizers` argument."
297
+ "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
298
+ )
299
+ self.tb_writer = tb_writer
300
+ self.log_history = []
301
+ if "prediction_loss_only" in kwargs:
302
+ warnings.warn(
303
+ "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
304
+ FutureWarning,
305
+ )
306
+ self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
307
+ assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
308
+
309
+ if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
310
+ self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
311
+ if not is_tensorboard_available():
312
+ logger.warning(
313
+ "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
314
+ )
315
+
316
+ # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
317
+ self._loggers_initialized = False
318
+
319
+ # Create output directory if needed
320
+ if self.is_world_process_zero():
321
+ os.makedirs(self.args.output_dir, exist_ok=True)
322
+ if is_torch_tpu_available():
323
+ # Set an xla_device flag on the model's config.
324
+ # We'll find a more elegant and not need to do this in the future.
325
+ self.model.config.xla_device = True
326
+ if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
327
+ self.data_collator = self.data_collator.collate_batch
328
+ warnings.warn(
329
+ (
330
+ "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
331
+ + "with a `collate_batch` are deprecated and won't be supported in a future version."
332
+ ),
333
+ FutureWarning,
334
+ )
335
+
336
+ if is_datasets_available():
337
+ if isinstance(train_dataset, datasets.Dataset):
338
+ self._remove_unused_columns(self.train_dataset, description="training")
339
+ if isinstance(eval_dataset, datasets.Dataset):
340
+ self._remove_unused_columns(self.eval_dataset, description="evaluation")
341
+
342
+ self.global_step = None
343
+ self.epoch = None
344
+ self.total_flos = None
345
+ if self.args.fp16 and _use_native_amp:
346
+ self.scaler = torch.cuda.amp.GradScaler()
347
+ self.hp_search_backend = None
348
+ self.use_tune_checkpoints = False
349
+ if self.args.label_names is None:
350
+ self.args.label_names = (["labels"]
351
+ )
352
+
353
+ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
354
+ if not self.args.remove_unused_columns:
355
+ return
356
+ # Inspect model forward signature to keep only the arguments it accepts.
357
+ signature = inspect.signature(self.model.forward)
358
+ signature_columns = list(signature.parameters.keys())
359
+ # Labels may be named label or label_ids, the default data collator handles that.
360
+ signature_columns += ["label", "label_ids"]
361
+ columns = [k for k in signature_columns if k in dataset.column_names]
362
+ ignored_columns = list(set(dataset.column_names) - set(signature_columns))
363
+ dset_description = "" if description is None else f"in the {description} set "
364
+ logger.info(
365
+ f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
366
+ )
367
+ dataset.set_format(type=dataset.format["type"], columns=columns)
368
+
369
+ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
370
+ if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
371
+ return None
372
+ elif is_torch_tpu_available():
373
+ return get_tpu_sampler(self.train_dataset)
374
+ else:
375
+ return (
376
+ RandomSampler(self.train_dataset)
377
+ if self.args.local_rank == -1
378
+ else DistributedSampler(self.train_dataset)
379
+ )
380
+
381
+ def get_train_dataloader(self) -> DataLoader:
382
+ """
383
+ Returns the training :class:`~torch.utils.data.DataLoader`.
384
+
385
+ Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
386
+ (adapted to distributed training if necessary) otherwise.
387
+
388
+ Subclass and override this method if you want to inject some custom behavior.
389
+ """
390
+ if self.train_dataset is None:
391
+ raise ValueError("Trainer: training requires a train_dataset.")
392
+ train_sampler = self._get_train_sampler()
393
+
394
+ return DataLoader(
395
+ self.train_dataset,
396
+ batch_size=self.args.train_batch_size,
397
+ sampler=train_sampler,
398
+ collate_fn=self.data_collator,
399
+ drop_last=self.args.dataloader_drop_last,
400
+ num_workers=self.args.dataloader_num_workers,
401
+ worker_init_fn=np.random.seed(self.args.seed)
402
+ )
403
+
404
+ def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
405
+ if isinstance(eval_dataset, torch.utils.data.IterableDataset):
406
+ return None
407
+ elif is_torch_tpu_available():
408
+ return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
409
+ elif self.args.local_rank != -1:
410
+ return SequentialDistributedSampler(eval_dataset)
411
+ else:
412
+ return SequentialSampler(eval_dataset)
413
+
414
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
415
+ """
416
+ Returns the evaluation :class:`~torch.utils.data.DataLoader`.
417
+
418
+ Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
419
+ sampler (adapted to distributed training if necessary) otherwise.
420
+
421
+ Subclass and override this method if you want to inject some custom behavior.
422
+
423
+ Args:
424
+ eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
425
+ If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
426
+ accepted by the ``model.forward()`` method are automatically removed.
427
+ """
428
+ if eval_dataset is None and self.eval_dataset is None:
429
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
430
+ elif eval_dataset is not None and is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
431
+ self._remove_unused_columns(eval_dataset, description="evaluation")
432
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
433
+ eval_sampler = self._get_eval_sampler(eval_dataset)
434
+
435
+ return DataLoader(
436
+ eval_dataset,
437
+ sampler=eval_sampler,
438
+ batch_size=self.args.eval_batch_size,
439
+ collate_fn=self.data_collator,
440
+ drop_last=self.args.dataloader_drop_last,
441
+ num_workers=self.args.dataloader_num_workers,
442
+ worker_init_fn=np.random.seed(self.args.seed)
443
+ )
444
+
445
+ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
446
+ """
447
+ Returns the test :class:`~torch.utils.data.DataLoader`.
448
+
449
+ Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
450
+ sampler (adapted to distributed training if necessary) otherwise.
451
+
452
+ Subclass and override this method if you want to inject some custom behavior.
453
+
454
+ Args:
455
+ eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
456
+ The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
457
+ ``model.forward()`` method are automatically removed.
458
+ """
459
+ if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
460
+ self._remove_unused_columns(test_dataset, description="test")
461
+ test_sampler = self._get_eval_sampler(test_dataset)
462
+
463
+ # We use the same batch_size as for eval.
464
+ return DataLoader(
465
+ test_dataset,
466
+ sampler=test_sampler,
467
+ batch_size=self.args.eval_batch_size,
468
+ collate_fn=self.data_collator,
469
+ drop_last=self.args.dataloader_drop_last,
470
+ worker_init_fn=np.random.seed(self.args.seed)
471
+ )
472
+
473
+ def create_optimizer_and_scheduler(self, num_training_steps: int):
474
+ """
475
+ Setup the optimizer and the learning rate scheduler.
476
+
477
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
478
+ Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
479
+ """
480
+ if self.optimizer is None:
481
+ no_decay = ["bias", "LayerNorm.weight"]
482
+ optimizer_grouped_parameters = [
483
+ {
484
+ "params": [p for n, p in self.model.named_parameters() if (not any(nd in n for nd in no_decay)) and p.requires_grad],
485
+ "weight_decay": self.args.weight_decay,
486
+ },
487
+ {
488
+ "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
489
+ "weight_decay": 0.0,
490
+ },
491
+ ]
492
+
493
+ self.optimizer = AdamW(
494
+ optimizer_grouped_parameters,
495
+ lr=self.args.learning_rate,
496
+ betas=(self.args.adam_beta1, self.args.adam_beta2),
497
+ eps=self.args.adam_epsilon,
498
+ )
499
+
500
+
501
+ # for n, p in self.model.named_parameters():
502
+ # print(n,p.requires_grad)
503
+ print(self.optimizer.state_dict())
504
+ if self.lr_scheduler is None:
505
+ self.lr_scheduler = get_linear_schedule_with_warmup(
506
+ self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
507
+ )
508
+
509
+
510
+ def setup_wandb(self):
511
+ """
512
+ Setup the optional Weights & Biases (`wandb`) integration.
513
+
514
+ One can subclass and override this method to customize the setup if needed. Find more information
515
+ `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
516
+
517
+ Environment:
518
+ WANDB_WATCH:
519
+ (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
520
+ or "all" to log gradients and parameters
521
+ WANDB_PROJECT:
522
+ (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
523
+ WANDB_DISABLED:
524
+ (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
525
+ """
526
+ if hasattr(self, "_setup_wandb"):
527
+ warnings.warn(
528
+ "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
529
+ FutureWarning,
530
+ )
531
+ return self._setup_wandb()
532
+
533
+ if self.is_world_process_zero():
534
+ logger.info(
535
+ 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
536
+ )
537
+ try:
538
+ combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
539
+ except AttributeError:
540
+ # in case the model has no config
541
+ combined_dict = {**self.args.to_sanitized_dict()}
542
+ wandb.init(
543
+ project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
544
+ )
545
+ # keep track of model topology and gradients, unsupported on TPU
546
+ if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
547
+ wandb.watch(
548
+ self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
549
+ )
550
+
551
+ def setup_comet(self):
552
+ """
553
+ Setup the optional Comet.ml integration.
554
+
555
+ Environment:
556
+ COMET_MODE:
557
+ (Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
558
+ COMET_PROJECT_NAME:
559
+ (Optional): str - Comet.ml project name for experiments
560
+ COMET_OFFLINE_DIRECTORY:
561
+ (Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"
562
+
563
+ For a number of configurable items in the environment,
564
+ see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
565
+ """
566
+ if self.is_world_master():
567
+ comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
568
+ args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
569
+ experiment = None
570
+ if comet_mode == "ONLINE":
571
+ experiment = comet_ml.Experiment(**args)
572
+ logger.info("Automatic Comet.ml online logging enabled")
573
+ elif comet_mode == "OFFLINE":
574
+ args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
575
+ experiment = comet_ml.OfflineExperiment(**args)
576
+ logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
577
+ if experiment is not None:
578
+ experiment._set_model_graph(self.model, framework="transformers")
579
+ experiment._log_parameters(self.args, prefix="args/", framework="transformers")
580
+ experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
581
+
582
+ def num_examples(self, dataloader: DataLoader) -> int:
583
+ """
584
+ Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
585
+ """
586
+ return len(dataloader.dataset)
587
+
588
+ def _setup_loggers(self):
589
+ if self._loggers_initialized:
590
+ return
591
+ if is_wandb_available():
592
+ self.setup_wandb()
593
+ elif os.environ.get("WANDB_DISABLED") != "true":
594
+ logger.info(
595
+ "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
596
+ "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
597
+ )
598
+ if is_comet_available():
599
+ self.setup_comet()
600
+ elif os.environ.get("COMET_MODE") != "DISABLED":
601
+ logger.info(
602
+ "To use comet_ml logging, run `pip/conda install comet_ml` "
603
+ "see https://www.comet.ml/docs/python-sdk/huggingface/"
604
+ )
605
+ self._loggers_initialized = True
606
+
607
+ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
608
+ """ HP search setup code """
609
+ if self.hp_search_backend is None or trial is None:
610
+ return
611
+ params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
612
+ for key, value in params.items():
613
+ if not hasattr(self.args, key):
614
+ raise AttributeError(
615
+ f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
616
+ )
617
+ old_attr = getattr(self.args, key, None)
618
+ # Casting value to the proper type
619
+ if old_attr is not None:
620
+ value = type(old_attr)(value)
621
+ setattr(self.args, key, value)
622
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
623
+ logger.info("Trial:", trial.params)
624
+
625
+ def _report_to_hp_search(
626
+ self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
627
+ ):
628
+ if self.hp_search_backend is None or trial is None:
629
+ return
630
+ self.objective = self.compute_objective(metrics)
631
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
632
+ trial.report(self.objective, epoch)
633
+ if trial.should_prune():
634
+ raise optuna.TrialPruned()
635
+ elif self.hp_search_backend == HPSearchBackend.RAY:
636
+ if self.global_step % self.args.save_steps == 0:
637
+ self._tune_save_checkpoint()
638
+ tune.report(objective=self.objective, **metrics)
639
+
640
+ def _tune_save_checkpoint(self):
641
+ if not self.use_tune_checkpoints:
642
+ return
643
+ with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir:
644
+ self.args.output_dir = checkpoint_dir
645
+ output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
646
+ self.save_model(output_dir)
647
+ if self.is_world_master():
648
+ torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
649
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
650
+
651
+
652
+ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
653
+ """
654
+ Main training entry point.
655
+
656
+ Args:
657
+ model_path (:obj:`str`, `optional`):
658
+ Local path to the model if the model to train has been instantiated from a local path. If present,
659
+ training will resume from the optimizer/scheduler states loaded here.
660
+ trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
661
+ The trial run or the hyperparameter dictionary for hyperparameter search.
662
+ """
663
+ # This might change the seed so needs to run first.
664
+ self._hp_search_setup(trial)
665
+
666
+ # Model re-init
667
+ if self.model_init is not None:
668
+ # Seed must be set before instantiating the model when using model_init.
669
+ set_seed(self.args.seed)
670
+ model = self.model_init()
671
+ self.model = model.to(self.args.device)
672
+
673
+ # Reinitializes optimizer and scheduler
674
+ self.optimizer, self.lr_scheduler = None, None
675
+
676
+ # Data loader and number of training steps
677
+ train_dataloader = self.get_train_dataloader()
678
+ num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
679
+ num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
680
+ if self.args.max_steps > 0:
681
+ t_total = self.args.max_steps
682
+ num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
683
+ self.args.max_steps % num_update_steps_per_epoch > 0
684
+ )
685
+ else:
686
+ t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
687
+ num_train_epochs = self.args.num_train_epochs
688
+ self.args.max_steps = t_total
689
+
690
+ self.create_optimizer_and_scheduler(num_training_steps=t_total)
691
+
692
+ # Check if saved optimizer or scheduler states exist
693
+ if (
694
+ model_path is not None
695
+ and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
696
+ and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
697
+ ):
698
+ # Load in optimizer and scheduler states
699
+ self.optimizer.load_state_dict(
700
+ torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
701
+ )
702
+ self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
703
+
704
+ model = self.model
705
+ if self.args.fp16 and _use_apex:
706
+ if not is_apex_available():
707
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
708
+ model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
709
+
710
+ # multi-gpu training (should be after apex fp16 initialization)
711
+ if self.args.n_gpu > 1:
712
+ model = torch.nn.DataParallel(model)
713
+
714
+ # Distributed training (should be after apex fp16 initialization)
715
+ if self.args.local_rank != -1:
716
+ model = torch.nn.parallel.DistributedDataParallel(
717
+ model,
718
+ device_ids=[self.args.local_rank],
719
+ output_device=self.args.local_rank,
720
+ find_unused_parameters=True,
721
+ )
722
+
723
+ if self.tb_writer is not None:
724
+ self.tb_writer.add_text("args", self.args.to_json_string())
725
+ self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
726
+
727
+ # Train!
728
+ if is_torch_tpu_available():
729
+ total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
730
+ else:
731
+ total_train_batch_size = (
732
+ self.args.train_batch_size
733
+ * self.args.gradient_accumulation_steps
734
+ * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
735
+ )
736
+ logger.info("***** Running training *****")
737
+ logger.info(" Num examples = %d", self.num_examples(train_dataloader))
738
+ logger.info(" Num Epochs = %d", num_train_epochs)
739
+ logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
740
+ logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
741
+ logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
742
+ logger.info(" Total optimization steps = %d", t_total)
743
+
744
+ self.global_step = 0
745
+ self.epoch = 0
746
+ self.total_flos = 0
747
+ epochs_trained = 0
748
+ steps_trained_in_current_epoch = 0
749
+ # Check if continuing training from a checkpoint
750
+ if model_path is not None:
751
+ # set global_step to global_step of last saved checkpoint from model path
752
+ try:
753
+ self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
754
+ # print(model, model.module)
755
+ if self.args.n_gpu > 1:
756
+ self.total_flos = getattr(model.module.config, "total_flos", 0)
757
+ else:
758
+ self.total_flos = getattr(model.config, "total_flos", 0)
759
+
760
+ epochs_trained = self.global_step // num_update_steps_per_epoch
761
+ steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
762
+
763
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
764
+ logger.info(" Continuing training from epoch %d", epochs_trained)
765
+ logger.info(" Continuing training from global step %d", self.global_step)
766
+ logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
767
+ logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
768
+ except ValueError:
769
+ self.global_step = 0
770
+ self.total_flos = 0
771
+ logger.info(" Starting fine-tuning.")
772
+
773
+ tr_loss = torch.tensor(0.0).to(self.args.device)
774
+ logging_loss_scalar = 0.0
775
+ model.zero_grad()
776
+ disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
777
+ train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
778
+ for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
779
+ if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
780
+ train_dataloader.sampler.set_epoch(epoch)
781
+
782
+ if is_torch_tpu_available():
783
+ parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
784
+ self.args.device
785
+ )
786
+ epoch_iterator = parallel_loader
787
+ else:
788
+ epoch_iterator = train_dataloader
789
+
790
+ # Reset the past mems state at the beginning of each epoch if necessary.
791
+ if self.args.past_index >= 0:
792
+ self._past = None
793
+
794
+ epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
795
+ for step, inputs in enumerate(epoch_iterator):
796
+
797
+ # Skip past any already trained steps if resuming training
798
+ if steps_trained_in_current_epoch > 0:
799
+ steps_trained_in_current_epoch -= 1
800
+ epoch_pbar.update(1)
801
+ continue
802
+
803
+ tr_loss += self.training_step(model, inputs)
804
+
805
+ self.total_flos += self.floating_point_ops(inputs)
806
+
807
+ if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
808
+ # last step in epoch but step is always smaller than gradient_accumulation_steps
809
+ len(epoch_iterator) <= self.args.gradient_accumulation_steps
810
+ and (step + 1) == len(epoch_iterator)
811
+ ):
812
+ if self.args.fp16 and _use_native_amp:
813
+ self.scaler.unscale_(self.optimizer)
814
+ torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
815
+ elif self.args.fp16 and _use_apex:
816
+ torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
817
+ else:
818
+ torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
819
+
820
+ if is_torch_tpu_available():
821
+ xm.optimizer_step(self.optimizer)
822
+ elif self.args.fp16 and _use_native_amp:
823
+ self.scaler.step(self.optimizer)
824
+ self.scaler.update()
825
+ else:
826
+ self.optimizer.step()
827
+
828
+ # URGENT
829
+ self.lr_scheduler.step()
830
+ model.zero_grad()
831
+ self.global_step += 1
832
+ self.epoch = epoch + (step + 1) / len(epoch_iterator)
833
+
834
+
835
+ if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
836
+ self.global_step == 1 and self.args.logging_first_step
837
+ ):
838
+ logs: Dict[str, float] = {}
839
+ tr_loss_scalar = tr_loss.item()
840
+ logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
841
+ # backward compatibility for pytorch schedulers
842
+ logs["learning_rate"] = (
843
+ self.lr_scheduler.get_last_lr()[0]
844
+ if version.parse(torch.__version__) >= version.parse("1.4")
845
+ else self.lr_scheduler.get_lr()[0]
846
+ )
847
+ logging_loss_scalar = tr_loss_scalar
848
+
849
+ self.log(logs)
850
+
851
+ # print(self.args.evaluation_strategy == EvaluationStrategy.STEPS )
852
+ # print(self.global_step % self.args.eval_steps == 0)
853
+ # print()
854
+
855
+ if (
856
+ self.args.evaluation_strategy == EvaluationStrategy.STEPS
857
+ and self.global_step % self.args.eval_steps == 0
858
+ ):
859
+ metrics = self.evaluate()
860
+ self._report_to_hp_search(trial, epoch, metrics)
861
+
862
+ #############################EARLY STOPPING########################
863
+ if 'lowdata' in self.args.output_dir or 'earlystop' in self.args.output_dir:
864
+ self.save_based_on_eval = True
865
+ else:
866
+ self.save_based_on_eval = False
867
+ print('if not see a line lowdata: below, then did not go into low data. ')
868
+ if self.save_based_on_eval and metrics["eval_loss"] < self.curr_best_eval:
869
+ print('lowdata:', self.global_step, self.curr_best_eval, metrics["eval_loss"],
870
+ 'perplexity={}'.format(math.exp(metrics["eval_loss"])))
871
+ self.curr_best_eval = metrics["eval_loss"]
872
+ if hasattr(model, "module"):
873
+ assert (
874
+ model.module is self.model
875
+ ), f"Module {model.module} should be a reference to self.model"
876
+ else:
877
+ assert model is self.model, f"Model {model} should be a reference to self.model"
878
+ # Save model checkpoint
879
+ output_dir_name = os.path.basename(self.args.output_dir)
880
+ checkpoint_folder = f"{output_dir_name}-{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
881
+ if self.hp_search_backend is not None and trial is not None:
882
+ run_id = (
883
+ trial.number
884
+ if self.hp_search_backend == HPSearchBackend.OPTUNA
885
+ else tune.get_trial_id()
886
+ )
887
+ checkpoint_folder += f"-run-{run_id}"
888
+ output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
889
+
890
+ self.store_flos()
891
+ print('saving to output_dir', output_dir)
892
+ self.save_model(output_dir)
893
+
894
+ if self.is_world_process_zero():
895
+ self._rotate_checkpoints(use_mtime=True)
896
+ #####################################################
897
+
898
+ if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
899
+ print('saving model at a checkpoint!!')
900
+ # In all cases (even distributed/parallel), self.model is always a reference
901
+ # to the model we want to save.
902
+ if hasattr(model, "module"):
903
+ assert (
904
+ model.module is self.model
905
+ ), f"Module {model.module} should be a reference to self.model"
906
+ else:
907
+ assert model is self.model, f"Model {model} should be a reference to self.model"
908
+ # Save model checkpoint
909
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
910
+ if self.hp_search_backend is not None and trial is not None:
911
+ run_id = (
912
+ trial.number
913
+ if self.hp_search_backend == HPSearchBackend.OPTUNA
914
+ else tune.get_trial_id()
915
+ )
916
+ checkpoint_folder += f"-run-{run_id}"
917
+ output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
918
+
919
+ self.store_flos()
920
+
921
+ self.save_model(output_dir)
922
+
923
+ if self.is_world_process_zero():
924
+ self._rotate_checkpoints(use_mtime=True)
925
+
926
+ if is_torch_tpu_available():
927
+ xm.rendezvous("saving_optimizer_states")
928
+ xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
929
+ xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
930
+ elif self.is_world_process_zero():
931
+ torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
932
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
933
+
934
+ epoch_pbar.update(1)
935
+ if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
936
+ break
937
+ epoch_pbar.close()
938
+ train_pbar.update(1)
939
+
940
+ if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
941
+ metrics = self.evaluate()
942
+ self._report_to_hp_search(trial, epoch, metrics)
943
+
944
+ if self.args.tpu_metrics_debug or self.args.debug:
945
+ if is_torch_tpu_available():
946
+ # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
947
+ xm.master_print(met.metrics_report())
948
+ else:
949
+ logger.warning(
950
+ "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
951
+ "configured. Check your training configuration if this is unexpected."
952
+ )
953
+ if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
954
+ break
955
+
956
+ train_pbar.close()
957
+ if self.tb_writer:
958
+ self.tb_writer.close()
959
+ if self.args.past_index and hasattr(self, "_past"):
960
+ # Clean the state at the end of training
961
+ delattr(self, "_past")
962
+
963
+ logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
964
+ return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
965
+
966
+ def hyperparameter_search(
967
+ self,
968
+ hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
969
+ compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
970
+ n_trials: int = 20,
971
+ direction: str = "minimize",
972
+ backend: Optional[Union["str", HPSearchBackend]] = None,
973
+ **kwargs
974
+ ) -> BestRun:
975
+ """
976
+ Launch an hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
977
+ :obj:`compute_objectie`, which defaults to a function returning the evaluation loss when no metric is provided,
978
+ the sum of all metrics otherwise.
979
+
980
+ .. warning::
981
+
982
+ To use this method, you need to have provided a ``model_init`` when initializing your
983
+ :class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
984
+ with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
985
+ method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.
986
+
987
+ Args:
988
+ hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
989
+ A function that defines the hyperparameter search space. Will default to
990
+ :func:`~transformers.trainer_utils.default_hp_space_optuna` or
991
+ :func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
992
+ compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
993
+ A function computing the objective to minimize or maximize from the metrics returned by the
994
+ :obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
995
+ n_trials (:obj:`int`, `optional`, defaults to 100):
996
+ The number of trial runs to test.
997
+ direction(:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
998
+ Whether to optimize greater or lower objects. Can be :obj:`"minimize"` or :obj:`"maximize"`, you should
999
+ pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
1000
+ several metrics.
1001
+ backend(:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
1002
+ The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
1003
+ one is installed. If both are installed, will default to optuna.
1004
+ kwargs:
1005
+ Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
1006
+ more information see:
1007
+
1008
+ - the documentation of `optuna.create_study <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
1009
+ - the documentation of `tune.run <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__
1010
+
1011
+ Returns:
1012
+ :class:`transformers.trainer_utils.BestRun`: All the informations about the best run.
1013
+ """
1014
+ if backend is None:
1015
+ backend = default_hp_search_backend()
1016
+ if backend is None:
1017
+ raise RuntimeError(
1018
+ "At least one of optuna or ray should be installed. "
1019
+ "To install optuna run `pip install optuna`."
1020
+ "To install ray run `pip install ray[tune]`."
1021
+ )
1022
+ backend = HPSearchBackend(backend)
1023
+ if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
1024
+ raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
1025
+ if backend == HPSearchBackend.RAY and not is_ray_available():
1026
+ raise RuntimeError(
1027
+ "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
1028
+ )
1029
+ self.hp_search_backend = backend
1030
+
1031
+ if self.model_init is None:
1032
+ raise RuntimeError(
1033
+ "To use hyperparameter search, you need to pass your model through a model_init function."
1034
+ )
1035
+
1036
+ self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
1037
+ self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
1038
+
1039
+ run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
1040
+ best_run = run_hp_search(self, n_trials, direction, **kwargs)
1041
+
1042
+ self.hp_search_backend = None
1043
+ return best_run
1044
+
1045
+ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
1046
+ """
1047
+ Log :obj:`logs` on the various objects watching training.
1048
+
1049
+ Subclass and override this method to inject custom behavior.
1050
+
1051
+ Args:
1052
+ logs (:obj:`Dict[str, float]`):
1053
+ The values to log.
1054
+ iterator (:obj:`tqdm`, `optional`):
1055
+ A potential tqdm progress bar to write the logs on.
1056
+ """
1057
+ # Set up loggers like W&B or Comet ML
1058
+ self._setup_loggers()
1059
+
1060
+ if hasattr(self, "_log"):
1061
+ warnings.warn(
1062
+ "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
1063
+ FutureWarning,
1064
+ )
1065
+ return self._log(logs, iterator=iterator)
1066
+
1067
+ if self.epoch is not None:
1068
+ logs["epoch"] = self.epoch
1069
+ if self.total_flos is not None:
1070
+ if self.args.local_rank != -1:
1071
+ total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
1072
+ else:
1073
+ total_flos = self.total_flos
1074
+ if total_flos > 0:
1075
+ logs["total_flos"] = self.total_flos
1076
+ if self.global_step is None:
1077
+ # when logging evaluation metrics without training
1078
+ self.global_step = 0
1079
+ if self.tb_writer:
1080
+ for k, v in logs.items():
1081
+ if isinstance(v, (int, float)):
1082
+ self.tb_writer.add_scalar(k, v, self.global_step)
1083
+ else:
1084
+ logger.warning(
1085
+ "Trainer is attempting to log a value of "
1086
+ '"%s" of type %s for key "%s" as a scalar. '
1087
+ "This invocation of Tensorboard's writer.add_scalar() "
1088
+ "is incorrect so we dropped this attribute.",
1089
+ v,
1090
+ type(v),
1091
+ k,
1092
+ )
1093
+ self.tb_writer.flush()
1094
+ if is_wandb_available():
1095
+ if self.is_world_process_zero():
1096
+ wandb.log(logs, step=self.global_step)
1097
+ if is_comet_available():
1098
+ if self.is_world_process_zero():
1099
+ experiment = comet_ml.config.get_global_experiment()
1100
+ if experiment is not None:
1101
+ experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
1102
+ output = {**logs, **{"step": self.global_step}}
1103
+ if self.is_world_process_zero():
1104
+ self.log_history.append(output)
1105
+ if iterator is not None:
1106
+ iterator.write(output)
1107
+ else:
1108
+ print(output)
1109
+
1110
+ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
1111
+ """
1112
+ Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
1113
+ handling potential state.
1114
+ """
1115
+ for k, v in inputs.items():
1116
+ if isinstance(v, torch.Tensor):
1117
+ inputs[k] = v.to(self.args.device)
1118
+
1119
+ if self.args.past_index >= 0 and self._past is not None:
1120
+ assert False
1121
+ inputs["mems"] = self._past
1122
+
1123
+ return inputs
1124
+
1125
+ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
1126
+ """
1127
+ Perform a training step on a batch of inputs.
1128
+
1129
+ Subclass and override to inject custom behavior.
1130
+
1131
+ Args:
1132
+ model (:obj:`nn.Module`):
1133
+ The model to train.
1134
+ inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
1135
+ The inputs and targets of the model.
1136
+
1137
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
1138
+ argument :obj:`labels`. Check your model's documentation for all accepted arguments.
1139
+
1140
+ Return:
1141
+ :obj:`torch.Tensor`: The tensor with training loss on this batch.
1142
+ """
1143
+ if hasattr(self, "_training_step"):
1144
+ warnings.warn(
1145
+ "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
1146
+ FutureWarning,
1147
+ )
1148
+ return self._training_step(model, inputs, self.optimizer)
1149
+
1150
+ model.train()
1151
+ if self.use_dropout:
1152
+ if self.gpt2 is not None:
1153
+ self.gpt2.train()
1154
+ inputs = self._prepare_inputs(inputs)
1155
+
1156
+ if self.args.fp16 and _use_native_amp:
1157
+ with autocast():
1158
+ if self.distill:
1159
+ loss = self.compute_loss_distill(model, inputs, gpt2_model=self.gpt2, )
1160
+ else:
1161
+ loss = self.compute_loss(model, inputs, gpt2_model=self.gpt2)
1162
+ else:
1163
+ if self.distill:
1164
+ loss = self.compute_loss_distill(model, inputs, gpt2_model=self.gpt2)
1165
+ else:
1166
+ loss = self.compute_loss(model, inputs, gpt2_model=self.gpt2)
1167
+
1168
+ if self.args.n_gpu > 1:
1169
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
1170
+
1171
+ if self.args.gradient_accumulation_steps > 1:
1172
+ loss = loss / self.args.gradient_accumulation_steps
1173
+
1174
+ if self.args.fp16 and _use_native_amp:
1175
+ self.scaler.scale(loss).backward()
1176
+ elif self.args.fp16 and _use_apex:
1177
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
1178
+ scaled_loss.backward()
1179
+ else:
1180
+ # print(loss)
1181
+ loss.backward()
1182
+
1183
+ # print('max allocated_memory:', torch.cuda.max_memory_allocated(0), 'total_memory:', torch.cuda.get_device_properties(0).total_memory,
1184
+ # 'percentage', torch.cuda.max_memory_allocated(0)/torch.cuda.get_device_properties(0).total_memory)
1185
+
1186
+
1187
+ return loss.detach()
1188
+
1189
+
1190
+
1191
+
1192
+
1193
+ def compute_loss(self, model, inputs, gpt2_model=None):
1194
+ """
1195
+ How the loss is computed by Trainer. By default, all models return the loss in the first element.
1196
+
1197
+ Subclass and override for custom behavior.
1198
+ """
1199
+ # outputs = model.forward_weighted(**inputs)
1200
+ if 'prompt_lab' in inputs:
1201
+ prompt_lab_ = inputs['prompt_lab']
1202
+ k = torch.cat(self.discri_labels_code, dim=0)
1203
+ inputs['control_code'] = torch.index_select(k, 0, prompt_lab_)
1204
+ del inputs['prompt_lab']
1205
+
1206
+ outputs = model(**inputs, gpt2_model=gpt2_model)
1207
+ # Save past state if it exists
1208
+ if self.args.past_index >= 0:
1209
+ self._past = outputs[self.args.past_index]
1210
+
1211
+ # print(outputs[0])
1212
+ # We don't use .loss here since the model may return tuples instead of ModelOutput.
1213
+ # print(outputs[0], outputs.loss)
1214
+ # URGENT
1215
+ # print('compute_loss', outputs[0])
1216
+ return outputs[0].mean()
1217
+
1218
+ def compute_loss_distill(self, model, inputs, gpt2_model=None):
1219
+ """
1220
+ How the loss is computed by Trainer. By default, all models return the loss in the first element.
1221
+
1222
+ Subclass and override for custom behavior.
1223
+ """
1224
+ # outputs = model.forward_weighted(**inputs)
1225
+
1226
+ with torch.no_grad():
1227
+ output_finetuned = self.finetuned_gpt2(**inputs)
1228
+
1229
+ outputs = model(**inputs, gpt2_model=gpt2_model)
1230
+ # Save past state if it exists
1231
+ if self.args.past_index >= 0:
1232
+ self._past = outputs[self.args.past_index]
1233
+
1234
+ if self.matching_objective == 'kl':
1235
+ # distrib_finetuned=torch.log_softmax(output_finetuned.logits[:,:,:-2], dim=-1) #bsz, seqlen, vocab
1236
+ distrib_finetuned=torch.log_softmax(output_finetuned.logits, dim=-1) #bsz, seqlen, vocab
1237
+ distrib_prefix = torch.log_softmax(outputs.logits, dim=-1) # bsz, seqlen, vocab
1238
+ loss = torch.sum(distrib_finetuned.exp() * (distrib_finetuned - distrib_prefix), dim=-1) #bsz, seqlen
1239
+
1240
+ elif self.matching_objective == 'logits':
1241
+ loss = torch.norm(output_finetuned.logits - outputs.logits, dim=-1) #bsz, seqlen
1242
+ # loss = torch.norm(output_finetuned.logits[:,:,:-2] - outputs.logits, dim=-1) #bsz, seqlen
1243
+
1244
+ elif self.matching_objective == 'last_layer':
1245
+ activation_diff = output_finetuned.last_hidden_state - outputs.last_hidden_state
1246
+ loss = torch.norm(activation_diff, dim=-1) # bsz, seqlen
1247
+ else:
1248
+ assert False, "invalid matching_objective"
1249
+
1250
+ return loss.sum(dim=-1).mean()
1251
+
1252
+ def is_local_master(self) -> bool:
1253
+ """
1254
+ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
1255
+ several machines) main process.
1256
+
1257
+ .. warning::
1258
+
1259
+ This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
1260
+ """
1261
+ warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
1262
+ return self.is_local_process_zero()
1263
+
1264
+ def is_local_process_zero(self) -> bool:
1265
+ """
1266
+ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
1267
+ several machines) main process.
1268
+ """
1269
+ if is_torch_tpu_available():
1270
+ return xm.is_master_ordinal(local=True)
1271
+ else:
1272
+ return self.args.local_rank in [-1, 0]
1273
+
1274
+ def is_world_master(self) -> bool:
1275
+ """
1276
+ Whether or not this process is the global main process (when training in a distributed fashion on
1277
+ several machines, this is only going to be :obj:`True` for one process).
1278
+
1279
+ .. warning::
1280
+
1281
+ This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
1282
+ """
1283
+ warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
1284
+ return self.is_world_process_zero()
1285
+
1286
+ def is_world_process_zero(self) -> bool:
1287
+ """
1288
+ Whether or not this process is the global main process (when training in a distributed fashion on
1289
+ several machines, this is only going to be :obj:`True` for one process).
1290
+ """
1291
+ if is_torch_tpu_available():
1292
+ return xm.is_master_ordinal(local=False)
1293
+ else:
1294
+ return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
1295
+
1296
+ def save_model(self, output_dir: Optional[str] = None):
1297
+ """
1298
+ Will save the model, so you can reload it using :obj:`from_pretrained()`.
1299
+
1300
+ Will only save from the world_master process (unless in TPUs).
1301
+ """
1302
+
1303
+ if is_torch_tpu_available():
1304
+ self._save_tpu(output_dir)
1305
+ elif self.is_world_process_zero():
1306
+ self._save(output_dir)
1307
+
1308
+ def _save_tpu(self, output_dir: Optional[str] = None):
1309
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
1310
+ logger.info("Saving model checkpoint to %s", output_dir)
1311
+
1312
+ if xm.is_master_ordinal():
1313
+ os.makedirs(output_dir, exist_ok=True)
1314
+ torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
1315
+ json.dump(
1316
+ self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
1317
+ )
1318
+
1319
+ # Save a trained model and configuration using `save_pretrained()`.
1320
+ # They can then be reloaded using `from_pretrained()`
1321
+ if not isinstance(self.model, PreTrainedModel):
1322
+ raise ValueError("Trainer.model appears to not be a PreTrainedModel")
1323
+
1324
+ xm.rendezvous("saving_checkpoint")
1325
+ self.model.save_pretrained(output_dir)
1326
+ if self.tokenizer is not None:
1327
+ self.tokenizer.save_pretrained(output_dir)
1328
+
1329
+ def _save(self, output_dir: Optional[str] = None):
1330
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
1331
+ os.makedirs(output_dir, exist_ok=True)
1332
+ logger.info("Saving model checkpoint to %s", output_dir)
1333
+ # Save a trained model and configuration using `save_pretrained()`.
1334
+ # They can then be reloaded using `from_pretrained()`
1335
+ if not isinstance(self.model, PreTrainedModel):
1336
+ raise ValueError("Trainer.model appears to not be a PreTrainedModel")
1337
+ self.model.save_pretrained(output_dir)
1338
+ if self.tokenizer is not None:
1339
+ self.tokenizer.save_pretrained(output_dir)
1340
+
1341
+ # Good practice: save your training arguments together with the trained model
1342
+ torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
1343
+ json.dump(
1344
+ self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
1345
+ )
1346
+
1347
+ def store_flos(self):
1348
+ # Storing the number of floating-point operations that went into the model
1349
+ if self.total_flos is not None:
1350
+ if self.args.local_rank != -1:
1351
+ total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
1352
+ else:
1353
+ total_flos = self.total_flos
1354
+ if total_flos > 0:
1355
+ self.model.config.total_flos = total_flos
1356
+
1357
+ def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
1358
+ output_dir_name = os.path.basename(self.args.output_dir)
1359
+ checkpoint_prefix = f"{output_dir_name}-{PREFIX_CHECKPOINT_DIR}"
1360
+
1361
+ ordering_and_checkpoint_path = []
1362
+
1363
+ glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]
1364
+
1365
+ for path in glob_checkpoints:
1366
+ if use_mtime:
1367
+ ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
1368
+ else:
1369
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
1370
+ if regex_match and regex_match.groups():
1371
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
1372
+
1373
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
1374
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
1375
+ return checkpoints_sorted
1376
+
1377
+ def _rotate_checkpoints(self, use_mtime=False) -> None:
1378
+ if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
1379
+ return
1380
+
1381
+ # Check if we should delete older checkpoint(s)
1382
+ checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
1383
+ if len(checkpoints_sorted) <= self.args.save_total_limit:
1384
+ return
1385
+
1386
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
1387
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
1388
+ for checkpoint in checkpoints_to_be_deleted:
1389
+ logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
1390
+ shutil.rmtree(checkpoint)
1391
+
1392
+ def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
1393
+ """
1394
+ Run evaluation and returns metrics.
1395
+
1396
+ The calling script will be responsible for providing a method to compute metrics, as they are
1397
+ task-dependent (pass it to the init :obj:`compute_metrics` argument).
1398
+
1399
+ You can also subclass and override this method to inject custom behavior.
1400
+
1401
+ Args:
1402
+ eval_dataset (:obj:`Dataset`, `optional`):
1403
+ Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
1404
+ columns not accepted by the ``model.forward()`` method are automatically removed.
1405
+
1406
+ Returns:
1407
+ A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
1408
+ """
1409
+ eval_dataloader = self.get_eval_dataloader(eval_dataset)
1410
+
1411
+ output = self.prediction_loop(eval_dataloader, description="Evaluation")
1412
+
1413
+ self.log(output.metrics)
1414
+
1415
+ if self.args.tpu_metrics_debug or self.args.debug:
1416
+ # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
1417
+ xm.master_print(met.metrics_report())
1418
+
1419
+ return output.metrics
1420
+
1421
+
1422
+
1423
+ def predict(self, test_dataset: Dataset) -> PredictionOutput:
1424
+ """
1425
+ Run prediction and returns predictions and potential metrics.
1426
+
1427
+ Depending on the dataset and your use case, your test dataset may contain labels.
1428
+ In that case, this method will also return metrics, like in :obj:`evaluate()`.
1429
+
1430
+ Args:
1431
+ test_dataset (:obj:`Dataset`):
1432
+ Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
1433
+ ``model.forward()`` method are automatically removed.
1434
+
1435
+ Returns:
1436
+ `NamedTuple`:
1437
+ predictions (:obj:`np.ndarray`):
1438
+ The predictions on :obj:`test_dataset`.
1439
+ label_ids (:obj:`np.ndarray`, `optional`):
1440
+ The labels (if the dataset contained some).
1441
+ metrics (:obj:`Dict[str, float]`, `optional`):
1442
+ The potential dictionary of metrics (if the dataset contained labels).
1443
+ """
1444
+ test_dataloader = self.get_test_dataloader(test_dataset)
1445
+
1446
+ return self.prediction_loop(test_dataloader, description="Prediction")
1447
+
1448
+ def prediction_loop(
1449
+ self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
1450
+ ) -> PredictionOutput:
1451
+ """
1452
+ Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
1453
+
1454
+ Works both with or without labels.
1455
+ """
1456
+ if hasattr(self, "_prediction_loop"):
1457
+ warnings.warn(
1458
+ "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
1459
+ FutureWarning,
1460
+ )
1461
+ return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)
1462
+
1463
+ prediction_loss_only = (
1464
+ prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
1465
+ )
1466
+
1467
+ assert not getattr(
1468
+ self.model.config, "output_attentions", False
1469
+ ), "The prediction loop does not work with `output_attentions=True`."
1470
+ assert not getattr(
1471
+ self.model.config, "output_hidden_states", False
1472
+ ), "The prediction loop does not work with `output_hidden_states=True`."
1473
+
1474
+ model = self.model
1475
+ # multi-gpu eval
1476
+ if self.args.n_gpu > 1:
1477
+ model = torch.nn.DataParallel(model)
1478
+ else:
1479
+ model = self.model
1480
+ # Note: in torch.distributed mode, there's no point in wrapping the model
1481
+ # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
1482
+
1483
+ batch_size = dataloader.batch_size
1484
+ logger.info("***** Running %s *****", description)
1485
+ logger.info(" Num examples = %d", self.num_examples(dataloader))
1486
+ logger.info(" Batch size = %d", batch_size)
1487
+ eval_losses: List[float] = []
1488
+ preds: torch.Tensor = None
1489
+ label_ids: torch.Tensor = None
1490
+ entropy_losses: List[float] = []
1491
+ model.eval()
1492
+ if self.gpt2 is not None:
1493
+ self.gpt2.eval()
1494
+
1495
+ print(model.training)
1496
+ print(self.gpt2.training)
1497
+
1498
+ if is_torch_tpu_available():
1499
+ dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
1500
+
1501
+ if self.args.past_index >= 0:
1502
+ self._past = None
1503
+
1504
+ disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
1505
+ for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
1506
+ loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
1507
+ batch_size = inputs[list(inputs.keys())[0]].shape[0]
1508
+ if loss is not None:
1509
+ eval_losses.extend([loss] * batch_size)
1510
+ if logits is not None:
1511
+ preds = logits if preds is None else nested_concat(preds, logits, dim=0)
1512
+ temp_logits = [torch.log_softmax(x) for x in logits]
1513
+ entropy_losses.extend([(x.exp() * x).sum() for x in temp_logits])
1514
+ if labels is not None:
1515
+ label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
1516
+
1517
+ if self.args.past_index and hasattr(self, "_past"):
1518
+ # Clean the state at the end of the evaluation loop
1519
+ delattr(self, "_past")
1520
+
1521
+
1522
+
1523
+ if self.compute_metrics is not None and preds is not None and label_ids is not None:
1524
+ metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
1525
+ else:
1526
+ metrics = {}
1527
+
1528
+ # Prefix all keys with eval_
1529
+ for key in list(metrics.keys()):
1530
+ if not key.startswith("eval_"):
1531
+ metrics[f"eval_{key}"] = metrics.pop(key)
1532
+ if len(entropy_losses) > 0:
1533
+ metrics['entropy'] = np.mean(entropy_losses)
1534
+ print('entropy', metrics['entropy'] )
1535
+
1536
+ return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
1537
+
1538
+ def prediction_step(
1539
+ self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
1540
+ ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
1541
+ """
1542
+ Perform an evaluation step on :obj:`model` using obj:`inputs`.
1543
+
1544
+ Subclass and override to inject custom behavior.
1545
+
1546
+ Args:
1547
+ model (:obj:`nn.Module`):
1548
+ The model to evaluate.
1549
+ inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
1550
+ The inputs and targets of the model.
1551
+
1552
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
1553
+ argument :obj:`labels`. Check your model's documentation for all accepted arguments.
1554
+ prediction_loss_only (:obj:`bool`):
1555
+ Whether or not to return the loss only.
1556
+
1557
+ Return:
1558
+ Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
1559
+ A tuple with the loss, logits and labels (each being optional).
1560
+ """
1561
+ has_labels = all(inputs.get(k) is not None for k in self.args.label_names)
1562
+ inputs = self._prepare_inputs(inputs)
1563
+
1564
+ # At eval time, set the weights to 1/bsz. and see the results..
1565
+
1566
+ # if 'weights' in inputs:
1567
+ # weights = inputs['weights']
1568
+ # bsz = weights.view(-1).shape[0]
1569
+ # weights = (torch.ones(weights.shape)/bsz).to(weights.device)
1570
+ # inputs['weights'] = weights
1571
+
1572
+ with torch.no_grad():
1573
+ # outputs = model.forward_weighted(**inputs)
1574
+ outputs = model(**inputs, gpt2_model=self.gpt2)
1575
+ if has_labels:
1576
+ # The .mean() is to reduce in case of distributed training
1577
+ loss = outputs[0].mean().item()
1578
+ logits = outputs[1:]
1579
+ else:
1580
+ loss = None
1581
+ # Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
1582
+ logits = outputs[:]
1583
+ if self.args.past_index >= 0:
1584
+ self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
1585
+
1586
+ if prediction_loss_only:
1587
+ return (loss, None, None)
1588
+
1589
+ logits = tuple(logit.detach() for logit in logits)
1590
+ if len(logits) == 1:
1591
+ logits = logits[0]
1592
+
1593
+ if has_labels:
1594
+ labels = tuple(inputs.get(name).detach() for name in self.args.label_names)
1595
+ if len(labels) == 1:
1596
+ labels = labels[0]
1597
+ else:
1598
+ labels = None
1599
+
1600
+ return (loss, logits, labels)
1601
+
1602
+ def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
1603
+ """
1604
+ For models that inherit from :class:`~transformers.PretrainedModel`, uses
1605
+ that method to compute the number of floating point operations for every backward + forward pass. If using
1606
+ another model, either implement such a method in the model or subclass and override this method.
1607
+
1608
+ Args:
1609
+ model (:obj:`nn.Module`):
1610
+ The model to evaluate.
1611
+ inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
1612
+ The inputs and targets of the model.
1613
+
1614
+ Returns:
1615
+ :obj:`int`: The number of floating-point operations.
1616
+ """
1617
+
1618
+ if isinstance(self.model, torch.nn.DataParallel) or isinstance(
1619
+ self.model, torch.nn.parallel.DistributedDataParallel
1620
+ ):
1621
+ model = self.model.module
1622
+ else:
1623
+ model = self.model
1624
+
1625
+ if hasattr(model, "floating_point_ops"):
1626
+ return model.floating_point_ops(inputs)
1627
+
1628
+ else:
1629
+ return 0
dalle/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .utils import *
2
+ from .config import *
3
+ from .sampling import *
dalle/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (214 Bytes). View file
 
dalle/utils/__pycache__/config.cpython-38.pyc ADDED
Binary file (7.78 kB). View file
 
dalle/utils/__pycache__/sampling.cpython-38.pyc ADDED
Binary file (6.86 kB). View file
 
dalle/utils/__pycache__/utils.cpython-38.pyc ADDED
Binary file (3.62 kB). View file
 
dalle/utils/config.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ from typing import Optional, List
8
+ from dataclasses import dataclass, field
9
+ from omegaconf import OmegaConf
10
+
11
+
12
+ @dataclass
13
+ class DataConfig:
14
+ dataset: Optional[str] = None
15
+ tokenizer_type: str = 'CharBPE'
16
+ context_length: int = 64
17
+ image_resolution: int = 256
18
+ transforms: str = 'dalle-vqvae'
19
+ bpe_pdrop: Optional[float] = None
20
+
21
+
22
+ @dataclass
23
+ class Stage1Hparams:
24
+ double_z: bool = False
25
+ z_channels: int = 256
26
+ resolution: int = 256
27
+ in_channels: int = 3
28
+ out_ch: int = 3
29
+ ch: int = 128
30
+ ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
31
+ num_res_blocks: int = 2
32
+ attn_resolutions: List[int] = field(default_factory=lambda: [16])
33
+ pdrop: float = 0.0
34
+
35
+
36
+ @dataclass
37
+ class Stage2Hparams:
38
+ embed_dim: int = 1536
39
+ n_layers: int = 42
40
+ n_heads: int = 24
41
+ n_dense_layers: int = 42
42
+ ctx_len_img: int = 256
43
+ ctx_len_txt: int = 64
44
+ embd_pdrop: float = 0.0
45
+ resid_pdrop: float = 0.0
46
+ attn_pdrop: float = 0.0
47
+ mlp_bias: bool = True
48
+ attn_bias: bool = True
49
+ gelu_use_approx: bool = False
50
+ use_head_txt: bool = True
51
+ n_classes: Optional[int] = None
52
+
53
+
54
+ @dataclass
55
+ class Stage1Config:
56
+ type: str = 'vqgan'
57
+ embed_dim: int = 256
58
+ n_embed: int = 16384
59
+ hparams: Stage1Hparams = Stage1Hparams()
60
+
61
+
62
+ @dataclass
63
+ class Stage2Config:
64
+ type: str = 'transformer1d'
65
+ vocab_size_txt: int = 16384
66
+ vocab_size_img: int = 16384
67
+ use_cls_cond: Optional[bool] = None
68
+ hparams: Stage2Hparams = Stage2Hparams()
69
+
70
+
71
+ @dataclass
72
+ class WarmupConfig:
73
+ epoch: int = 1
74
+ multiplier: int = 1
75
+ buffer_epoch: int = 0
76
+ min_lr: float = 0.0
77
+ mode: str = 'fix'
78
+ peak_lr: float = 1e-4
79
+ start_from_zero: bool = True
80
+
81
+
82
+ @dataclass
83
+ class OptConfig:
84
+ opt_type: str = 'adamW'
85
+ learning_rate: float = 5e-5
86
+ weight_decay: float = 1e-4
87
+ betas: List[float] = field(default_factory=lambda: [0.9, 0.99])
88
+ grad_clip_norm: float = 1.0
89
+
90
+ sched_type: str = 'cosine'
91
+ max_steps: int = 0
92
+ min_lr: float = 1e-6
93
+
94
+
95
+ @dataclass
96
+ class ExpConfig:
97
+ per_gpu_train_batch_size: int = 4
98
+ per_gpu_eval_batch_size: int = 32
99
+ num_train_epochs: int = 10
100
+ save_ckpt_freq: int = 1
101
+ test_freq: int = 10
102
+ use_amp: bool = True
103
+
104
+
105
+ @dataclass
106
+ class PrefixModelConfig:
107
+ model_name_or_path: Optional[str] = ''
108
+ prefix_model_name_or_path: str = ''
109
+ prefix_mode: str = 'activation'
110
+ tuning_mode: str = 'finetune'
111
+ top_k_layers: int = 2
112
+ parameterize_mode: str = 'mlp'
113
+ optim_prefix: bool = False
114
+ preseqlen: int = 10
115
+ prefix_dropout: float = 0.1
116
+ init_random: bool = False
117
+ hidden_dim_prefix: int = 512
118
+ lowdata: bool = False
119
+ lowdata_token: str = ''
120
+ init_shallow: bool = False
121
+ init_shallow_word: bool = False
122
+ teacher_dropout: float = 0.1
123
+ gumbel: bool = False
124
+ replay_buffer: bool = False
125
+
126
+
127
+ @dataclass
128
+ class PromptModelConfig:
129
+ model_name_or_path: Optional[str] = ''
130
+ prefix_model_name_or_path: str = ''
131
+ tuning_mode: str = 'prompt'
132
+ preseqlen: int = 10
133
+ prefix_dropout: float = 0.1
134
+
135
+
136
+ @dataclass
137
+ class StoryModelConfig:
138
+ model_name_or_path: Optional[str] = ''
139
+ prefix_model_name_or_path: str = ''
140
+ tuning_mode: str = 'story'
141
+ preseqlen: int = 10
142
+ prefix_dropout: float = 0.1
143
+ prompt: bool = False
144
+ story_len: int = 4
145
+ sent_embed: int = 256
146
+ condition: bool = False
147
+ clip_embed: bool = False
148
+
149
+
150
+ @dataclass
151
+ class DefaultConfig:
152
+ dataset: DataConfig = DataConfig()
153
+ stage1: Stage1Config = Stage1Config()
154
+ stage2: Stage2Config = Stage2Config()
155
+
156
+
157
+ @dataclass
158
+ class FineTuningConfig:
159
+ dataset: DataConfig = DataConfig()
160
+ stage1: Stage1Config = Stage1Config()
161
+ stage2: Stage2Config = Stage2Config()
162
+ optimizer: OptConfig = OptConfig()
163
+ experiment: ExpConfig = ExpConfig()
164
+
165
+
166
+ @dataclass
167
+ class PrefixTuningConfig:
168
+ dataset: DataConfig = DataConfig()
169
+ stage1: Stage1Config = Stage1Config()
170
+ stage2: Stage2Config = Stage2Config()
171
+ prefix: PrefixModelConfig = PrefixModelConfig()
172
+ optimizer: OptConfig = OptConfig()
173
+ experiment: ExpConfig = ExpConfig()
174
+
175
+
176
+ @dataclass
177
+ class PromptTuningConfig:
178
+ dataset: DataConfig = DataConfig()
179
+ stage1: Stage1Config = Stage1Config()
180
+ stage2: Stage2Config = Stage2Config()
181
+ prompt: PromptModelConfig = PromptModelConfig()
182
+ optimizer: OptConfig = OptConfig()
183
+ experiment: ExpConfig = ExpConfig()
184
+
185
+
186
+ @dataclass
187
+ class StoryConfig:
188
+ dataset: DataConfig = DataConfig()
189
+ stage1: Stage1Config = Stage1Config()
190
+ stage2: Stage2Config = Stage2Config()
191
+ story: StoryModelConfig = StoryModelConfig()
192
+ optimizer: OptConfig = OptConfig()
193
+ experiment: ExpConfig = ExpConfig()
194
+
195
+
196
+ def get_base_config(mode):
197
+ if mode == 'default':
198
+ return OmegaConf.structured(DefaultConfig)
199
+ elif mode == 'finetuning':
200
+ return OmegaConf.structured(FineTuningConfig)
201
+ elif mode == 'prefixtuning':
202
+ return OmegaConf.structured(PrefixTuningConfig)
203
+ elif mode == 'prompt_tuning':
204
+ return OmegaConf.structured(PromptTuningConfig)
205
+ elif mode == 'story':
206
+ return OmegaConf.structured(StoryConfig)
207
+ else:
208
+ raise ValueError
209
+ # return OmegaConf.structured(DefaultConfig if use_default else FineTuningConfig)
dalle/utils/sampling.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import torch
8
+ from typing import Optional
9
+ from tqdm import tqdm
10
+ from torch.nn import functional as F
11
+
12
+
13
+ torch.set_printoptions(precision=2, threshold=10)
14
+ def cutoff_topk_logits(logits: torch.FloatTensor, k: int) -> torch.FloatTensor:
15
+ if k is None:
16
+ return logits
17
+ else:
18
+ v, ix = torch.topk(logits, k)
19
+ out = logits.clone()
20
+ out[out < v[:, [-1]]] = -float('Inf')
21
+ return out
22
+
23
+
24
+ def cutoff_topp_probs(probs: torch.FloatTensor, p: float) -> torch.FloatTensor:
25
+ if p is None:
26
+ return probs
27
+ else:
28
+ sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
29
+ cum_probs = torch.cumsum(sorted_probs, dim=-1)
30
+
31
+ sorted_idx_remove_cond = cum_probs >= p
32
+
33
+ sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
34
+ sorted_idx_remove_cond[..., 0] = 0
35
+
36
+ indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
37
+ probs = probs.masked_fill(indices_to_remove, 0.0)
38
+ norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
39
+ return norm_probs
40
+
41
+
42
+ def get_positional_encoding(inputs: torch.LongTensor, mode: str = '1d') -> torch.LongTensor:
43
+ device = inputs.device
44
+ if mode == '1d':
45
+ B, N = inputs.shape
46
+ xs_pos = torch.arange(N, device=device).repeat((B, 1))
47
+ elif mode == '2d':
48
+ B, H, W = inputs.shape
49
+ xs_pos_h = torch.arange(H, device=device).repeat(B, W, 1).transpose(1, 2)
50
+ xs_pos_w = torch.arange(W, device=device).repeat(B, H, 1)
51
+ xs_pos = (xs_pos_h, xs_pos_w)
52
+ else:
53
+ raise ValueError('%s positional encoding invalid' % mode)
54
+ return xs_pos
55
+
56
+
57
+ @torch.no_grad()
58
+ def sampling(model: torch.nn.Module,
59
+ tokens: torch.LongTensor,
60
+ top_k: Optional[float] = None,
61
+ top_p: Optional[float] = None,
62
+ softmax_temperature: float = 1.0,
63
+ is_tqdm: bool = True,
64
+ use_fp16: bool = True,
65
+ max_seq_len: int = 256,
66
+ prompt: Optional[torch.tensor] = None,
67
+ pos_prompt: Optional[torch.Tensor] = None) -> torch.LongTensor:
68
+
69
+ code = None
70
+ past = None
71
+
72
+ pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
73
+ pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
74
+
75
+ for cnt, h in enumerate(pbar):
76
+ if code is None:
77
+ code_ = None
78
+ pos_enc_code_ = None
79
+ else:
80
+ code_ = code.clone().detach()
81
+ pos_enc_code_ = get_positional_encoding(code_, mode='1d')
82
+ code_ = code_[:, cnt-1].unsqueeze(-1)
83
+ pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
84
+
85
+ logits, present = model.sampling(images=code_,
86
+ texts=tokens,
87
+ pos_images=pos_enc_code_,
88
+ pos_texts=pos_enc_tokens,
89
+ use_fp16=use_fp16,
90
+ past=past,
91
+ prompt=prompt,
92
+ pos_prompt=pos_prompt)
93
+
94
+ logits = logits.to(dtype=torch.float32)
95
+ logits = logits / softmax_temperature
96
+
97
+ # print(len(present), present[0].shape)
98
+ present = torch.stack(present).clone().detach()
99
+ if past is None:
100
+ past = [present]
101
+ else:
102
+ past.append(present)
103
+
104
+ logits = cutoff_topk_logits(logits, top_k)
105
+ probs = F.softmax(logits, dim=-1)
106
+ probs = cutoff_topp_probs(probs, top_p)
107
+ # print(probs[0])
108
+
109
+ idx = torch.multinomial(probs, num_samples=1).clone().detach()
110
+ # print(idx)
111
+ code = idx if code is None else torch.cat([code, idx], axis=1)
112
+
113
+ del past
114
+ return code
115
+
116
+
117
+ @torch.no_grad()
118
+ def sampling_prefix(model: torch.nn.Module,
119
+ tokens: torch.LongTensor,
120
+ past: torch.FloatTensor,
121
+ top_k: Optional[float] = None,
122
+ top_p: Optional[float] = None,
123
+ softmax_temperature: float = 1.0,
124
+ is_tqdm: bool = True,
125
+ use_fp16: bool = True,
126
+ max_seq_len: int = 256,
127
+ labels = None) -> torch.LongTensor:
128
+ code = None
129
+
130
+ pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
131
+ pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
132
+
133
+ # print("Entering sampling_prefix; ", past.shape)
134
+ if past is not None:
135
+ past = [past]
136
+
137
+ for cnt, h in enumerate(pbar):
138
+ if code is None:
139
+ code_ = None
140
+ pos_enc_code_ = None
141
+ else:
142
+ code_ = code.clone().detach()
143
+ pos_enc_code_ = get_positional_encoding(code_, mode='1d')
144
+ code_ = code_[:, cnt-1].unsqueeze(-1)
145
+ pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
146
+
147
+ # print("Looop enter")
148
+ # print(cnt, past[0].shape)
149
+ # print("-------------------")
150
+ logits, present = model.sampling(images=code_,
151
+ texts=tokens,
152
+ pos_images=pos_enc_code_,
153
+ pos_texts=pos_enc_tokens,
154
+ use_fp16=use_fp16,
155
+ past=past)
156
+ logits = logits.to(dtype=torch.float32)
157
+ logits = logits / softmax_temperature
158
+
159
+ present = torch.stack(present).clone().detach()
160
+
161
+ # print('Present', present.shape)
162
+
163
+ if past is None:
164
+ past = [present]
165
+ else:
166
+ # print("Loop end")
167
+ # print(present.shape)
168
+ # print("-----------------")
169
+
170
+ # n_layers, temp, _, seq_len, n_dim = present.shape
171
+ # _, _, bs, n_heads, pre_seq_len, n_dim = past[0].shape
172
+ # assert temp == 2
173
+ # past.append(present.view(n_layers, temp, bs, n_heads, seq_len, n_dim))
174
+
175
+ past.append(present)
176
+
177
+ logits = cutoff_topk_logits(logits, top_k)
178
+ probs = F.softmax(logits, dim=-1)
179
+ probs = cutoff_topp_probs(probs, top_p)
180
+ print(torch.topk(probs, 5, dim=-1))
181
+ if labels is not None:
182
+ print(labels[cnt])
183
+ idx = torch.multinomial(probs, num_samples=1).clone().detach()
184
+ # print(idx)
185
+ code = idx if code is None else torch.cat([code, idx], axis=1)
186
+
187
+ del past
188
+ return code
189
+
190
+
191
+ @torch.no_grad()
192
+ def sampling_prefix_new(model: torch.nn.Module,
193
+ tokens: torch.LongTensor,
194
+ past: torch.FloatTensor,
195
+ top_k: Optional[float] = None,
196
+ top_p: Optional[float] = None,
197
+ softmax_temperature: float = 1.0,
198
+ is_tqdm: bool = True,
199
+ use_fp16: bool = True,
200
+ max_seq_len: int = 256) -> torch.LongTensor:
201
+ code = None
202
+
203
+ pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
204
+ pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
205
+
206
+ # print("Entering sampling_prefix; ", past.shape)
207
+ if past is not None:
208
+ past = [past]
209
+
210
+ for cnt, h in enumerate(pbar):
211
+ if code is None:
212
+ code_ = None
213
+ pos_enc_code_ = None
214
+ else:
215
+ code_ = code.clone().detach()
216
+ pos_enc_code_ = get_positional_encoding(code_, mode='1d')
217
+ # code_ = code_[:, cnt-1].unsqueeze(-1)
218
+ # pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
219
+
220
+ # print("Looop enter")
221
+ # print(cnt, past[0].shape)
222
+ # print("-------------------")
223
+
224
+ if cnt == 0:
225
+ logits, present = model.sampling(images=code_,
226
+ texts=tokens,
227
+ pos_images=pos_enc_code_,
228
+ pos_texts=pos_enc_tokens,
229
+ use_fp16=use_fp16,
230
+ past=past)
231
+ logits = logits.to(dtype=torch.float32)
232
+ logits = logits / softmax_temperature
233
+
234
+ present = torch.stack(present).clone().detach()
235
+
236
+ # print('Present', present.shape)
237
+
238
+ if past is None:
239
+ past = [present]
240
+ else:
241
+ pass
242
+
243
+ logits = cutoff_topk_logits(logits, top_k)
244
+ probs = F.softmax(logits, dim=-1)
245
+ probs = cutoff_topp_probs(probs, top_p)
246
+ # print(torch.topk(probs[0], 5))
247
+ idx = torch.multinomial(probs, num_samples=1).clone().detach()
248
+ # print(idx)
249
+ code = idx if code is None else torch.cat([code, idx], axis=1)
250
+
251
+ else:
252
+ pass
253
+
254
+
255
+ del past
256
+ return code
257
+
258
+ @torch.no_grad()
259
+ def sampling_conditional(model: torch.nn.Module,
260
+ cross_attention_idxs,
261
+ cross_attention_layers,
262
+ tokens: torch.LongTensor,
263
+ src_codes: torch.FloatTensor,
264
+ top_k: Optional[float] = None,
265
+ top_p: Optional[float] = None,
266
+ softmax_temperature: float = 1.0,
267
+ is_tqdm: bool = True,
268
+ use_fp16: bool = True,
269
+ max_seq_len: int = 256,
270
+ prompt: Optional[torch.tensor] = None,
271
+ pos_prompt: Optional[torch.Tensor] = None) -> torch.LongTensor:
272
+
273
+ code = None
274
+ past = None
275
+
276
+ pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
277
+ pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
278
+
279
+ src_pos_tokens = get_positional_encoding(src_codes, mode='1d')
280
+ src_tokens = model.tok_emb_img(src_codes)
281
+ src_tokens = src_tokens + model.pos_emb_img(src_pos_tokens)
282
+
283
+ for cnt, h in enumerate(pbar):
284
+ if code is None:
285
+ code_ = None
286
+ pos_enc_code_ = None
287
+ else:
288
+ code_ = code.clone().detach()
289
+ pos_enc_code_ = get_positional_encoding(code_, mode='1d')
290
+ code_ = code_[:, cnt-1].unsqueeze(-1)
291
+ pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
292
+
293
+ logits, present = model.sampling_with_context(images=code_,
294
+ cross_attention_idxs=cross_attention_idxs,
295
+ cross_attention_layers=cross_attention_layers,
296
+ texts=tokens,
297
+ pos_images=pos_enc_code_,
298
+ pos_texts=pos_enc_tokens,
299
+ source_image=src_tokens,
300
+ use_fp16=use_fp16,
301
+ past=past,
302
+ prompt=prompt,
303
+ pos_prompt=pos_prompt)
304
+ logits = logits.to(dtype=torch.float32)
305
+ logits = logits / softmax_temperature
306
+
307
+ present = torch.stack(present).clone().detach()
308
+ if past is None:
309
+ past = [present]
310
+ else:
311
+ past.append(present)
312
+
313
+ logits = cutoff_topk_logits(logits, top_k)
314
+ probs = F.softmax(logits, dim=-1)
315
+ probs = cutoff_topp_probs(probs, top_p)
316
+
317
+ idx = torch.multinomial(probs, num_samples=1).clone().detach()
318
+ code = idx if code is None else torch.cat([code, idx], axis=1)
319
+
320
+ del past
321
+ return code
322
+
323
+
324
+ @torch.no_grad()
325
+ def sampling_igpt(model: torch.nn.Module,
326
+ sos: torch.FloatTensor,
327
+ top_k: Optional[float] = None,
328
+ top_p: Optional[float] = None,
329
+ softmax_temperature: float = 1.0,
330
+ is_tqdm: bool = True,
331
+ use_fp16: bool = True,
332
+ max_seq_len: int = 256) -> torch.LongTensor:
333
+ code = None
334
+ past = None
335
+ pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
336
+
337
+ for cnt, h in enumerate(pbar):
338
+ if code is None:
339
+ code_ = None
340
+ pos_enc_code_ = None
341
+ else:
342
+ code_ = code.clone().detach()
343
+ pos_enc_code_ = get_positional_encoding(code_, mode='1d')
344
+ code_ = code_[:, cnt-1].unsqueeze(-1)
345
+ pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
346
+
347
+ logits, present = model.sampling(sos=sos,
348
+ codes=code_,
349
+ pos_codes=pos_enc_code_,
350
+ use_fp16=use_fp16,
351
+ past=past)
352
+ logits = logits.to(dtype=torch.float32)
353
+ logits = logits / softmax_temperature
354
+
355
+ present = torch.stack(present).clone().detach()
356
+ if past is None:
357
+ past = [present]
358
+ else:
359
+ past.append(present)
360
+
361
+ logits = cutoff_topk_logits(logits, top_k)
362
+ probs = F.softmax(logits, dim=-1)
363
+ probs = cutoff_topp_probs(probs, top_p)
364
+
365
+ idx = torch.multinomial(probs, num_samples=1).clone().detach()
366
+ code = idx if code is None else torch.cat([code, idx], axis=1)
367
+
368
+ del past
369
+ return code
dalle/utils/utils.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ import random
9
+ import urllib
10
+ import hashlib
11
+ import tarfile
12
+ import torch
13
+ import clip
14
+ import numpy as np
15
+ from PIL import Image
16
+ from torch.nn import functional as F
17
+ from tqdm import tqdm
18
+ import torchvision.utils as vutils
19
+ import matplotlib.pyplot as plt
20
+
21
+
22
+ def set_seed(seed: int):
23
+ random.seed(seed)
24
+ np.random.seed(seed)
25
+ torch.manual_seed(seed)
26
+ torch.cuda.manual_seed_all(seed)
27
+
28
+
29
+ @torch.no_grad()
30
+ def clip_score(prompt: str,
31
+ images: np.ndarray,
32
+ model_clip: torch.nn.Module,
33
+ preprocess_clip,
34
+ device: str) -> np.ndarray:
35
+ images = [preprocess_clip(Image.fromarray((image*255).astype(np.uint8))) for image in images]
36
+ images = torch.stack(images, dim=0).to(device=device)
37
+ texts = clip.tokenize(prompt).to(device=device)
38
+ texts = torch.repeat_interleave(texts, images.shape[0], dim=0)
39
+
40
+ image_features = model_clip.encode_image(images)
41
+ text_features = model_clip.encode_text(texts)
42
+
43
+ scores = F.cosine_similarity(image_features, text_features).squeeze()
44
+ rank = torch.argsort(scores, descending=True).cpu().numpy()
45
+ return rank
46
+
47
+
48
+ def download(url: str, root: str) -> str:
49
+ os.makedirs(root, exist_ok=True)
50
+ filename = os.path.basename(url)
51
+ pathname = filename[:-len('.tar.gz')]
52
+
53
+ expected_md5 = url.split("/")[-2]
54
+ download_target = os.path.join(root, filename)
55
+ result_path = os.path.join(root, pathname)
56
+
57
+ if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
58
+ return result_path
59
+
60
+ with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
61
+ with tqdm(total=int(source.info().get('Content-Length')), ncols=80, unit='iB', unit_scale=True,
62
+ unit_divisor=1024) as loop:
63
+ while True:
64
+ buffer = source.read(8192)
65
+ if not buffer:
66
+ break
67
+
68
+ output.write(buffer)
69
+ loop.update(len(buffer))
70
+
71
+ if hashlib.md5(open(download_target, 'rb').read()).hexdigest() != expected_md5:
72
+ raise RuntimeError(f'Model has been downloaded but the md5 checksum does not not match')
73
+
74
+ with tarfile.open(download_target, 'r:gz') as f:
75
+ pbar = tqdm(f.getmembers(), total=len(f.getmembers()))
76
+ for member in pbar:
77
+ pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
78
+ f.extract(member=member, path=root)
79
+
80
+ return result_path
81
+
82
+
83
+ def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
84
+ if urllib.parse.urlparse(url_or_path).scheme in ('http', 'https'):
85
+ return download(url_or_path, root)
86
+ return url_or_path
87
+
88
+
89
+ def images_to_numpy(tensor):
90
+ generated = tensor.data.cpu().numpy().transpose(1,2,0)
91
+ generated[generated < -1] = -1
92
+ generated[generated > 1] = 1
93
+ generated = (generated + 1) / 2 * 255
94
+ return generated.astype('uint8')
95
+
96
+
97
+ def save_image(ground_truth, images, out_dir, batch_idx):
98
+
99
+ for i, im in enumerate(images):
100
+ if len(im.shape) == 3:
101
+ plt.imsave(os.path.join(out_dir, 'test_%s_%s.png' % (batch_idx, i)), im)
102
+ else:
103
+ bs = im.shape[0]
104
+ # plt.imsave()
105
+ for j in range(bs):
106
+ plt.imsave(os.path.join(out_dir, 'test_%s_%s_%s.png' % (batch_idx, i, j)), im[j])
107
+
108
+
109
+ # print("Ground truth Images shape: ", ground_truth.shape, len(images))
110
+
111
+ # images = vutils.make_grid(images, nrow=ground_truth.shape[0])
112
+ # images = images_to_numpy(images)
113
+ #
114
+ # if ground_truth is not None:
115
+ # ground_truth = vutils.make_grid(ground_truth, 5)
116
+ # ground_truth = images_to_numpy(ground_truth)
117
+ # print("Ground Truth shape, Generated Images shape: ", ground_truth.shape, images.shape)
118
+ # images = np.concatenate([ground_truth, images], axis=0)
119
+ #
120
+ # output = Image.fromarray(images)
121
+ # output.save('%s/fake_samples_epoch_%03d.png' % (out_dir, batch_idx))
122
+
123
+ # if texts is not None:
124
+ # fid = open('%s/fake_samples_epoch_%03d_%03d.txt' % (image_dir, epoch, idx), 'w')
125
+ # for idx in range(images.shape[0]):
126
+ # fid.write(str(idx) + '--------------------------------------------------------\n')
127
+ # for i in range(len(texts)):
128
+ # fid.write(texts[i][idx] + '\n')
129
+ # fid.write('\n\n')
130
+ # fid.close()
131
+ return
demo/Barney.png ADDED
demo/Betty.png ADDED
demo/Crong.png ADDED
demo/Dino.png ADDED
demo/Eddy.png ADDED
demo/Fred.png ADDED
demo/Harry.png ADDED
demo/Loopy.png ADDED
demo/MrSlate.png ADDED
demo/Pebbles.png ADDED
demo/Petty.png ADDED
demo/Poby.png ADDED
demo/Pororo.png ADDED
demo/Rody.png ADDED
demo/Tongtong.png ADDED