boris committed on
Commit
39caefb
2 Parent(s): 1d04ab3 a09ea25

Merge branch 'main' of https://github.com/borisdayma/dalle-mini into fix-opt_state

.github/workflows/sync_to_hub_debug.yml ADDED
@@ -0,0 +1,17 @@
1
+ name: Deploy to debug app
2
+
3
+ on:
4
+ # to run this workflow manually from the Actions tab
5
+ workflow_dispatch:
6
+
7
+ jobs:
8
+ sync-to-hub-debug:
9
+ runs-on: ubuntu-latest
10
+ steps:
11
+ - uses: actions/checkout@v2
12
+ with:
13
+ fetch-depth: 0
14
+ - name: Push to hub
15
+ env:
16
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
17
+ run: git push --force https://boris:[email protected]/spaces/flax-community/dalle-mini-debug +HEAD:main
CITATION.cff ADDED
@@ -0,0 +1,44 @@
1
+ # YAML 1.2
2
+ ---
3
+ abstract: "DALL·E mini is a JAX/Flax reimplementation of OpenAI's DALL·E that requires much smaller hardware resources. By simplifying the architecture and model memory requirements, as well as leveraging open-source code and pre-trained models, we were able to create a model that is 27 times smaller than the original DALL·E and train it on a single TPU v3-8 for only 3 days. DALL·E mini achieves impressive results, albeit of a lower quality than the original system. It can be used for exploration and further experimentation on commodity hardware."
4
+ authors:
5
+ -
6
+ family-names: Dayma
7
+ given-names: Boris
8
+ -
9
+ family-names: Patil
10
+ given-names: Suraj
11
+ -
12
+ family-names: Cuenca
13
+ given-names: Pedro
14
+ -
15
+ family-names: Saifullah
16
+ given-names: Khalid
17
+ -
18
+ family-names: Abraham
19
+ given-names: Tanishq
20
+ -
21
+ family-names: "Lê Khắc"
22
+ given-names: "Phúc"
23
+ -
24
+ family-names: Melas
25
+ given-names: Luke
26
+ -
27
+ family-names: Ghosh
28
+ given-names: Ritobrata
29
+ cff-version: "1.1.0"
30
+ date-released: 2021-07-29
31
+ identifiers:
32
+ keywords:
33
+ - dalle
34
+ - "text-to-image generation"
35
+ - transformer
36
+ - "zero-shot"
37
+ - JAX
38
+ license: "Apache-2.0"
39
+ doi: 10.5281/zenodo.5146400
40
+ message: "If you use this project, please cite it using these metadata."
41
+ repository-code: "https://github.com/borisdayma/dalle-mini"
42
+ title: "DALL·E Mini"
43
+ version: "v0.1-alpha"
44
+ ...
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: DALL·E mini
3
  emoji: 🥑
4
- colorFrom: red
5
- colorTo: purple
6
  sdk: streamlit
7
  app_file: app/app.py
8
  pinned: false
@@ -16,7 +16,7 @@ _Generate images from a text prompt_
16
 
17
  Our logo was generated with DALL·E mini using the prompt "logo of an armchair in the shape of an avocado".
18
 
19
- You can create your own pictures with [the demo](https://huggingface.co/spaces/flax-community/dalle-mini) (temporarily in beta on Huging Face Spaces but soon to be open to all).
20
 
21
  ## How does it work?
22
 
@@ -26,8 +26,6 @@ Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini
26
 
27
  ### Dependencies Installation
28
 
29
- The root folder and associated [`requirements.txt`](./requirements.txt) is only for the app.
30
-
31
  For development, use [`dev/requirements.txt`](dev/requirements.txt) or [`dev/environment.yaml`](dev/environment.yaml).
32
 
33
  ### Training of VQGAN
@@ -52,7 +50,16 @@ To generate sample predictions and understand the inference pipeline step by ste
52
 
53
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/dev/inference/inference_pipeline.ipynb)
54
 
55
- ## Where does the logo come from?
56
 
57
  The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone to us.
58
 
@@ -70,4 +77,66 @@ The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL
70
  ## Acknowledgements
71
 
72
  - 🤗 Hugging Face for organizing [the FLAX/JAX community week](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects)
73
- - Google Cloud team for providing access to TPU's
1
  ---
2
  title: DALL·E mini
3
  emoji: 🥑
4
+ colorFrom: yellow
5
+ colorTo: green
6
  sdk: streamlit
7
  app_file: app/app.py
8
  pinned: false
 
16
 
17
  Our logo was generated with DALL·E mini using the prompt "logo of an armchair in the shape of an avocado".
18
 
19
+ You can create your own pictures with [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).
20
 
21
  ## How does it work?
22
 
 
26
 
27
  ### Dependencies Installation
28
 
 
 
29
  For development, use [`dev/requirements.txt`](dev/requirements.txt) or [`dev/environment.yaml`](dev/environment.yaml).
30
 
31
  ### Training of VQGAN
 
50
 
51
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/dev/inference/inference_pipeline.ipynb)
52
 
53
+ ## FAQ
54
+
55
+ ### Where to find the latest models?
56
+
57
+ Trained models are on the 🤗 Model Hub:
58
+
59
+ - [VQGAN-f16-16384](https://huggingface.co/flax-community/vqgan_f16_16384) for encoding/decoding images
60
+ - [DALL·E mini](https://huggingface.co/flax-community/dalle-mini) for generating images from a text prompt
61
+
62
+ ### Where does the logo come from?
63
 
64
  The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone to us.
65
 
 
77
  ## Acknowledgements
78
 
79
  - 🤗 Hugging Face for organizing [the FLAX/JAX community week](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects)
80
+ - Google [TPU Research Cloud (TRC) program](https://sites.research.google/trc/) for providing computing resources
81
+ - [Weights & Biases](https://wandb.com/) for providing the infrastructure for experiment tracking and model management
82
+
83
+ ## Citing DALL·E mini
84
+
85
+ If you find DALL·E mini useful in your research or wish to refer to it, please use the following BibTeX entry.
86
+
87
+ ```
88
+ @misc{Dayma_DALL·E_Mini_2021,
89
+ author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
90
+ doi = {10.5281/zenodo.5146400},
91
+ month = {7},
92
+ title = {DALL·E Mini},
93
+ url = {https://github.com/borisdayma/dalle-mini},
94
+ year = {2021}
95
+ }
96
+ ```
97
+
98
+ ## References
99
+
100
+ ```
101
+ @misc{ramesh2021zeroshot,
102
+ title={Zero-Shot Text-to-Image Generation},
103
+ author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
104
+ year={2021},
105
+ eprint={2102.12092},
106
+ archivePrefix={arXiv},
107
+ primaryClass={cs.CV}
108
+ }
109
+ ```
110
+
111
+ ```
112
+ @misc{esser2021taming,
113
+ title={Taming Transformers for High-Resolution Image Synthesis},
114
+ author={Patrick Esser and Robin Rombach and Björn Ommer},
115
+ year={2021},
116
+ eprint={2012.09841},
117
+ archivePrefix={arXiv},
118
+ primaryClass={cs.CV}
119
+ }
120
+ ```
121
+
122
+ ```
123
+ @misc{lewis2019bart,
124
+ title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
125
+ author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
126
+ year={2019},
127
+ eprint={1910.13461},
128
+ archivePrefix={arXiv},
129
+ primaryClass={cs.CL}
130
+ }
131
+ ```
132
+
133
+ ```
134
+ @misc{radford2021learning,
135
+ title={Learning Transferable Visual Models From Natural Language Supervision},
136
+ author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
137
+ year={2021},
138
+ eprint={2103.00020},
139
+ archivePrefix={arXiv},
140
+ primaryClass={cs.CV}
141
+ }
142
+ ```
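The FAQ above points to the two pretrained checkpoints on the Hub. As a minimal sketch of loading the VQGAN checkpoint and encoding images with the `vqgan_jax` wrapper used in the encoding notebooks below (the random batch is only a placeholder):

```python
# Minimal sketch: load the VQGAN checkpoint listed in the FAQ and encode a
# batch of images into token indices, as done in dev/encoding/ below.
# Images are assumed to be 256x256, channels-last, with values in [0, 1].
import numpy as np
from vqgan_jax.modeling_flax_vqgan import VQModel

model = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

images = np.random.rand(4, 256, 256, 3).astype(np.float32)  # placeholder batch
_, indices = model.encode(images)  # one sequence of codebook indices per image
print(indices.shape)
```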
app/app.py CHANGED
@@ -1,7 +1,6 @@
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
 
4
- import random
5
  from dalle_mini.backend import ServiceError, get_images_from_backend
6
 
7
  import streamlit as st
@@ -55,12 +54,31 @@ st.subheader('Generate images from text')
55
 
56
  prompt = st.text_input("What do you want to see?")
57
 
58
- #TODO: I think there's an issue where we can't run twice the same inference (not due to caching) - may need to use st.form
59
-
60
  DEBUG = False
61
  if prompt != "" or (should_run_again and prompt != ""):
62
  container = st.empty()
63
- container.markdown(f"Generating predictions for: **{prompt}**")
64
 
65
  try:
66
  backend_url = st.secrets["BACKEND_SERVER"]
 
1
  #!/usr/bin/env python
2
  # coding: utf-8
3
 
 
4
  from dalle_mini.backend import ServiceError, get_images_from_backend
5
 
6
  import streamlit as st
 
54
 
55
  prompt = st.text_input("What do you want to see?")
56
 
57
+ test = st.empty()
 
58
  DEBUG = False
59
  if prompt != "" or (should_run_again and prompt != ""):
60
  container = st.empty()
61
+ # The following mimics `streamlit.info()`.
62
+ # I tried to get the secondary background color using `components.streamlit.config.get_options_for_section("theme")["secondaryBackgroundColor"]`
63
+ # but it returns None.
64
+ container.markdown(f"""
65
+ <style> p {{ margin:0 }} div {{ margin:0 }} </style>
66
+ <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
67
+ <div class="stAlert">
68
+ <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
69
+ <div class="st-b7">
70
+ <div class="css-whx05o e13vu3m50">
71
+ <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
72
+ <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/img/loading.gif" width="30"/>
73
+ Generating predictions for: <b>{prompt}</b>
74
+ </div>
75
+ </div>
76
+ </div>
77
+ </div>
78
+ </div>
79
+ </div>
80
+ <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
81
+ """, unsafe_allow_html=True)
82
 
83
  try:
84
  backend_url = st.secrets["BACKEND_SERVER"]
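For reference, a hypothetical sketch of how the app might call the inference backend once a prompt is entered. The exact signature of `get_images_from_backend` lives in `dalle_mini/backend.py`, which is not part of this diff; taking `(prompt, backend_url)` and returning a list of PIL images is an assumption.

```python
# Hypothetical usage of the backend helper imported at the top of app/app.py.
# The (prompt, backend_url) signature and the returned PIL images are assumptions.
import streamlit as st
from dalle_mini.backend import ServiceError, get_images_from_backend

prompt = "logo of an armchair in the shape of an avocado"
backend_url = st.secrets["BACKEND_SERVER"]
try:
    images = get_images_from_backend(prompt, backend_url)
    for image in images:
        st.image(image)
except ServiceError as err:
    st.error(f"Backend unavailable ({err})")
```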
app/dalle_mini DELETED
@@ -1 +0,0 @@
1
- ../dalle_mini/
app/gradio/dalle_mini DELETED
@@ -1 +0,0 @@
1
- ../../dalle_mini/
app/img/loading.gif ADDED
dalle_mini/text.py ADDED
@@ -0,0 +1,272 @@
1
+ """
2
+ Utilities for processing text.
3
+ """
4
+
5
+ import requests
6
+ from pathlib import Path
7
+ from unidecode import unidecode
8
+
9
+ import re, math, random, html
10
+ import ftfy
11
+
12
+ WIKI_STATS_URL = "https://github.com/borisdayma/wikipedia-word-frequency/raw/feat-update/results/enwiki-20210820-words-frequency.txt"
13
+ WIKI_STATS_LOCAL = Path(WIKI_STATS_URL).parts[-1]
14
+
15
+ # based on wiki word occurence
16
+ person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
17
+ temp_token = "xtokx" # avoid repeating chars
18
+
19
+
20
+ def get_wiki_file():
21
+ if not Path(WIKI_STATS_LOCAL).exists():
22
+ r = requests.get(WIKI_STATS_URL, stream=True)
23
+ with open(WIKI_STATS_LOCAL, "wb") as fd:
24
+ for chunk in r.iter_content(chunk_size=128):
25
+ fd.write(chunk)
26
+ return WIKI_STATS_LOCAL
27
+
28
+
29
+ class HashtagProcessor:
30
+ # Adapted from wordninja library
31
+ # We use our wikipedia word count + a good heuristic to make it work
32
+ def __init__(self):
33
+ self._word_cost = (
34
+ l.split()[0] for l in Path(get_wiki_file()).read_text().splitlines()
35
+ )
36
+ self._word_cost = {
37
+ str(k): math.log(float(i + 1)) for i, k in enumerate(self._word_cost)
38
+ }
39
+ self._max_word = max(len(x) for x in self._word_cost.keys())
40
+ self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")
41
+
42
+ def __call__(self, s):
43
+ """Uses dynamic programming to infer the location of spaces in a string without spaces."""
44
+ l = [self._split(x) for x in self._SPLIT_RE.split(s)]
45
+ return " ".join([item for sublist in l for item in sublist])
46
+
47
+ def _split(self, s):
48
+ # Find the best match for the i first characters, assuming cost has
49
+ # been built for the i-1 first characters.
50
+ # Returns a pair (match_cost, match_length).
51
+ def best_match(i):
52
+ candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i]))
53
+ return min(
54
+ (c + self._word_cost.get(s[i - k - 1 : i].lower(), 9e999), k + 1)
55
+ for k, c in candidates
56
+ )
57
+
58
+ # Build the cost array
59
+ cost = [0]
60
+ for i in range(1, len(s) + 1):
61
+ c, k = best_match(i)
62
+ cost.append(c)
63
+
64
+ # Backtrack to recover the minimal-cost string.
65
+ out = []
66
+ i = len(s)
67
+ while i > 0:
68
+ c, k = best_match(i)
69
+ assert c == cost[i]
70
+ newToken = True
71
+ if not s[i - k : i] == "'": # ignore a lone apostrophe
72
+ if len(out) > 0:
73
+ # re-attach split 's and split digits
74
+ if out[-1] == "'s" or (
75
+ s[i - 1].isdigit() and out[-1][0].isdigit()
76
+ ): # digit followed by digit
77
+ out[-1] = (
78
+ s[i - k : i] + out[-1]
79
+ ) # combine current token with previous token
80
+ newToken = False
81
+
82
+ if newToken:
83
+ out.append(s[i - k : i])
84
+
85
+ i -= k
86
+
87
+ return reversed(out)
88
+
89
+
90
+ def replace_person_token(t):
91
+ "Used for CC12M"
92
+ t = re.sub("<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
93
+ while "<person>" in t:
94
+ t = t.replace(
95
+ "<person>", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1
96
+ )
97
+ return t
98
+
99
+
100
+ def fix_html(t):
101
+ "Adapted from fastai"
102
+ t = (
103
+ t.replace("#39;", "'")
104
+ .replace("&amp;", "&")
105
+ .replace("amp;", "&")
106
+ .replace("#146;", "'")
107
+ .replace("nbsp;", " ")
108
+ .replace("#36;", "$")
109
+ .replace("\\n", "\n")
110
+ .replace("quot;", "'")
111
+ .replace("<br />", "\n")
112
+ .replace('\\"', '"')
113
+ .replace("<unk>", " ")
114
+ .replace(" @.@ ", ".")
115
+ .replace(" @-@ ", "-")
116
+ )
117
+ return html.unescape(t)
118
+
119
+
120
+ def replace_punctuation_with_commas(t):
121
+ return re.sub("""([()[\].,|:;?!=+~\-])""", ",", t)
122
+
123
+
124
+ def simplify_quotes(t):
125
+ return re.sub("""['"`]""", ' " ', t)
126
+
127
+
128
+ def merge_quotes(t):
129
+ return re.sub('(\s*"+\s*)+', ' " ', t)
130
+
131
+
132
+ def remove_comma_numbers(t):
133
+ def _f(t):
134
+ return re.sub("(\d),(\d{3})", r"\1\2", t)
135
+
136
+ return _f(_f(t))
137
+
138
+
139
+ def pre_process_dot_numbers(t):
140
+ return re.sub("(\d)\.(\d)", fr"\1{temp_token}dot{temp_token}\2", t)
141
+
142
+
143
+ def post_process_dot_numbers(t):
144
+ return re.sub(f"{temp_token}dot{temp_token}", ".", t)
145
+
146
+
147
+ def pre_process_quotes(t):
148
+ # allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've
149
+ return re.sub(
150
+ r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)", fr"{temp_token}quote{temp_token}", t
151
+ )
152
+
153
+
154
+ def post_process_quotes(t):
155
+ return re.sub(f"{temp_token}quote{temp_token}", "'", t)
156
+
157
+
158
+ def merge_commas(t):
159
+ return re.sub("(\s*,+\s*)+", ", ", t)
160
+
161
+
162
+ def add_space_after_commas(t):
163
+ return re.sub(",", ", ", t)
164
+
165
+
166
+ def handle_special_chars(t):
167
+ "Handle special characters"
168
+ # replace "-" with a space when between words without space
169
+ t = re.sub("([a-zA-Z])-([a-zA-Z])", r"\1 \2", t)
170
+ # always add space around &
171
+ return re.sub("&", " & ", t)
172
+
173
+
174
+ def expand_hashtags(t, hashtag_processor):
175
+ "Remove # and try to split words"
176
+ return re.sub("#(\w+)", lambda m: hashtag_processor(m.group(1)), t)
177
+
178
+
179
+ _re_ignore_chars = """[_#\/\\%]"""
180
+
181
+
182
+ def ignore_chars(t):
183
+ "Ignore useless characters"
184
+ return re.sub(_re_ignore_chars, " ", t)
185
+
186
+
187
+ def remove_extra_spaces(t):
188
+ "Remove extra spaces (including \t and \n)"
189
+ return re.sub("\s+", " ", t)
190
+
191
+
192
+ def remove_repeating_chars(t):
193
+ "If the same character is present 4+ times (not 3 because of roman 'VIII'), replace with single instance"
194
+ return re.sub(r"(\D)(\1{3,})", r"\1", t)
195
+
196
+
197
+ def remove_urls(t):
198
+ return re.sub(r"http\S+", "", t)
199
+
200
+
201
+ def remove_html_tags(t):
202
+ return re.sub("<[^<]+?>", "", t)
203
+
204
+
205
+ def remove_first_last_commas(t):
206
+ t = t.strip()
207
+ t = t[:-1] if t and t[-1] == "," else t
208
+ t = t[1:] if t and t[0] == "," else t
209
+ return t.strip()
210
+
211
+
212
+ def remove_wiki_ref(t):
213
+ t = re.sub(r"\A\s*\[\d+\]", "", t)
214
+ return re.sub(r"\[\d+\]\s*\Z", "", t)
215
+
216
+
217
+ class TextNormalizer:
218
+ "Normalize text"
219
+
220
+ def __init__(self):
221
+ self._hashtag_processor = HashtagProcessor()
222
+
223
+ def __call__(self, t, clip=False):
224
+
225
+ # fix some characters
226
+ t = ftfy.fix_text(t)
227
+ # fix html
228
+ t = fix_html(t)
229
+ if not clip:
230
+ # decode and simplify text: see unidecode library
231
+ t = unidecode(t)
232
+ # lower case
233
+ t = t.lower()
234
+ # replace <PERSON> (for CC12M)
235
+ t = replace_person_token(t)
236
+ # remove wiki reference (for WIT)
237
+ t = remove_wiki_ref(t)
238
+ # remove html tags
239
+ t = remove_html_tags(t)
240
+ # remove urls
241
+ t = remove_urls(t)
242
+ # remove commas in numbers
243
+ t = remove_comma_numbers(t)
244
+ if not clip:
245
+ # handle dots in numbers and quotes - Part 1
246
+ t = pre_process_dot_numbers(t)
247
+ t = pre_process_quotes(t)
248
+ # handle special characters
249
+ t = handle_special_chars(t)
250
+ # handle hashtags
251
+ t = expand_hashtags(t, self._hashtag_processor)
252
+ # ignore useless characters
253
+ t = ignore_chars(t)
254
+ # simplify quotes
255
+ t = simplify_quotes(t)
256
+ # all punctuation becomes commas
257
+ t = replace_punctuation_with_commas(t)
258
+ # handle dots in numbers and quotes - Part 2
259
+ t = post_process_dot_numbers(t)
260
+ t = post_process_quotes(t)
261
+ # handle repeating characters
262
+ t = remove_repeating_chars(t)
263
+ # merge commas
264
+ t = merge_commas(t)
265
+ # merge quotes
266
+ t = merge_quotes(t)
267
+ # remove multiple spaces
268
+ t = remove_extra_spaces(t)
269
+ # remove first and last comma
270
+ t = remove_first_last_commas(t)
271
+ # always start with a space
272
+ return f" {t}" if not clip else t
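A short usage sketch of the normalizer defined above (the caption is illustrative; instantiating `TextNormalizer` downloads the Wikipedia word-frequency file on first use, so network access is required):

```python
# Usage sketch for dalle_mini.text. The example caption is illustrative only;
# building TextNormalizer downloads the word-frequency file the first time.
from dalle_mini.text import TextNormalizer

normalizer = TextNormalizer()
caption = "Check out our #SummerVacation photos! <person> & <person> at the beach https://example.com"
print(normalizer(caption))             # full normalization: lower-cased, hashtags split, URL removed
print(normalizer(caption, clip=True))  # clip=True skips part of the transformations
```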
dev/README.md ADDED
@@ -0,0 +1,122 @@
1
+ # Development Instructions for TPU
2
+
3
+ ## Setup
4
+
5
+ - Apply to the [TRC program](https://sites.research.google/trc/) for free TPU credits if you're eligible.
6
+ - Follow the [Cloud TPU VM User's Guide](https://cloud.google.com/tpu/docs/users-guide-tpu-vm) to set up gcloud.
7
+ - Verify `gcloud config list`, in particular account, project & zone.
8
+ - Create a TPU VM per the guide and connect to it.
9
+
10
+ When needing a larger disk:
11
+
12
+ - Create a balanced persistent disk (SSD, so pricier than default HDD but much faster): `gcloud compute disks create DISK_NAME --size SIZE_IN_GB --type pd-balanced`
13
+ - Attach the disk to your instance by adding `--data-disk source=REF` per ["Adding a persistent disk to a TPU VM" guide](https://cloud.google.com/tpu/docs/setup-persistent-disk), eg `gcloud alpha compute tpus tpu-vm create INSTANCE_NAME --accelerator-type=v3-8 --version=v2-alpha --data-disk source=projects/tpu-toys/zones/europe-west4-a/disks/DISK_NAME`
14
+ - Format the partition as described in the guide.
15
+ - Make sure to set up automatic remount of disk at restart.
16
+
17
+ ## Connect VS Code
18
+
19
+ - Find external IP in the UI or with `gcloud alpha compute tpus tpu-vm describe INSTANCE_NAME`
20
+ - Verify you can connect in terminal with `ssh EXTERNAL_IP -i ~/.ssh/google_compute_engine`
21
+ - Add the same command as ssh host in VS Code.
22
+ - Check config file
23
+
24
+ ```
25
+ Host INSTANCE_NAME
26
+ HostName EXTERNAL_IP
27
+ IdentityFile ~/.ssh/google_compute_engine
28
+ ```
29
+
30
+ ## Environment configuration
31
+
32
+ ### Use virtual environments (optional)
33
+
34
+ We recommend using virtual environments (such as conda, venv or pyenv-virtualenv).
35
+
36
+ If you want to use `pyenv` and `pyenv-virtualenv`:
37
+
38
+ - Installation
39
+
40
+ - [Set up build environment](https://github.com/pyenv/pyenv/wiki#suggested-build-environment)
41
+ - Use [pyenv-installer](https://github.com/pyenv/pyenv-installer): `curl https://pyenv.run | bash`
42
+ - bash set-up:
43
+
44
+ ```bash
45
+ echo -e '\n'\
46
+ '# pyenv setup \n'\
47
+ 'export PYENV_ROOT="$HOME/.pyenv" \n'\
48
+ 'export PATH="$PYENV_ROOT/bin:$PATH" \n'\
49
+ 'eval "$(pyenv init --path)" \n'\
50
+ 'eval "$(pyenv init -)" \n'\
51
+ 'eval "$(pyenv virtualenv-init -)"' >> ~/.bashrc
52
+ ```
53
+
54
+ - Usage
55
+
56
+ - Install a python version: `pyenv install X.X.X`
57
+ - Create a virtual environment: `pyenv virtualenv 3.9.6 dalle_env`
58
+ - Activate: `pyenv activate dalle_env`
59
+
60
+ Note: you can auto-activate your environment at a location with `echo dalle_env >> .python-version`
61
+
62
+ ### Tools
63
+
64
+ - Git
65
+
66
+ - `git config --global user.email "[email protected]"`
67
+ - `git config --global user.name "First Last"`
68
+
69
+ - Github CLI
70
+
71
+ - See [installation instructions](https://github.com/cli/cli/blob/trunk/docs/install_linux.md)
72
+ - `gh auth login`
73
+
74
+ - Direnv
75
+
76
+ - Install direnv: `sudo apt-get update && sudo apt-get install direnv`
77
+ - bash set-up:
78
+
79
+ ```bash
80
+ echo -e '\n'\
81
+ '# direnv setup \n'\
82
+ 'eval "$(direnv hook bash)" \n' >> ~/.bashrc
83
+ ```
84
+
85
+ ### Set up repo
86
+
87
+ - Clone repo: `gh repo clone borisdayma/dalle-mini`
88
+ - If using `pyenv-virtualenv`, auto-activate env: `echo dalle_env >> .python-version`
89
+
90
+ ## Environment
91
+
92
+ - Install the following (use it later to update our dev requirements.txt)
93
+
94
+ ```
95
+ requests
96
+ pillow
97
+ jupyterlab
98
+ ipywidgets
99
+
100
+ -e ../datasets[streaming]
101
+ -e ../transformers
102
+ -e ../webdataset
103
+
104
+ # JAX
105
+ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
106
+ jax[tpu]>=0.2.16
107
+ flax
108
+ ```
109
+
110
+ - `transformers-cli login`
111
+
112
+ ---
113
+
114
+ - set `HF_HOME="/mnt/disks/persist/cache/huggingface"` in `/etc/environment` and ensure you have required permissions, then restart.
115
+
116
+ ## Working with datasets or models
117
+
118
+ - Install [Git LFS](https://github.com/git-lfs/git-lfs/wiki/Installation)
119
+ - Clone a dataset without large files: `GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/datasets/.../...`
120
+ - Use a local [credential store](https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage) for caching credentials
121
+ - Track specific extensions: `git lfs track "*.ext"`
122
+ - See files tracked with LFS with `git lfs ls-files`
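Once `jax[tpu]` from the requirements above is installed on the TPU VM, a quick sanity check (not part of the repo) is to confirm that JAX sees all cores; a v3-8 is expected to report 8 devices:

```python
# Sanity check after installing jax[tpu] on the TPU VM (a v3-8 should show 8 cores).
import jax

print(jax.devices())              # TpuDevice entries when the VM is set up correctly
print(jax.device_count())         # total accelerator cores visible to JAX
print(jax.local_device_count())   # cores on this host
```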
dev/encoding/vqgan-jax-encoding-streaming.ipynb ADDED
@@ -0,0 +1,562 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d0b72877",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VQGAN JAX Encoding for 🤗 Datasets in streaming mode"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "ba7b31e6",
14
+ "metadata": {},
15
+ "source": [
16
+ "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and 🤗 Datasets in streaming mode.\n",
17
+ "\n",
18
+ "This example uses our YFCC100M dataset, but it should be easy to adapt to any other image/caption dataset in the huggingface hub."
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "3b59489e",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import io\n",
29
+ "\n",
30
+ "import requests\n",
31
+ "from PIL import Image\n",
32
+ "import numpy as np\n",
33
+ "from tqdm import tqdm\n",
34
+ "\n",
35
+ "import torch\n",
36
+ "import torchvision.transforms as T\n",
37
+ "import torchvision.transforms.functional as TF\n",
38
+ "from torchvision.transforms import InterpolationMode\n",
39
+ "import os\n",
40
+ "\n",
41
+ "import jax\n",
42
+ "from jax import pmap"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "markdown",
47
+ "id": "c7c4c1e6",
48
+ "metadata": {},
49
+ "source": [
50
+ "## Dataset and Parameters"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "id": "d45a289e",
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "import datasets\n",
61
+ "from datasets import Dataset, load_dataset"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "markdown",
66
+ "id": "f26e4f18",
67
+ "metadata": {},
68
+ "source": [
69
+ "We'll use the `validation` set for testing. Adjust accordingly."
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "id": "28893c3e",
76
+ "metadata": {},
77
+ "outputs": [],
78
+ "source": [
79
+ "dataset = load_dataset('dalle-mini/YFCC100M_OpenAI_subset', use_auth_token=True, streaming=True, split='validation')"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": null,
85
+ "id": "33861477",
86
+ "metadata": {},
87
+ "outputs": [],
88
+ "source": [
89
+ "from pathlib import Path\n",
90
+ "\n",
91
+ "yfcc100m = Path.home()/'data'/'YFCC100M_OpenAI_subset'\n",
92
+ "yfcc100m_output = yfcc100m/'encoded' # Output directory for encoded files"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": null,
98
+ "id": "6e7b71c4",
99
+ "metadata": {},
100
+ "outputs": [],
101
+ "source": [
102
+ "batch_size = 128 # Per device\n",
103
+ "num_workers = 16 # Unused in streaming mode"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "markdown",
108
+ "id": "0793c26a",
109
+ "metadata": {},
110
+ "source": [
111
+ "### Data preparation"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "markdown",
116
+ "id": "86415769",
117
+ "metadata": {},
118
+ "source": [
119
+ "* Images: we transform them so they are center-cropped and square, all of the same size so we can build batches for TPU/GPU processing.\n",
120
+ "* Captions: we extract a single `caption` column from the source data, by concatenating the cleaned title and description.\n",
121
+ "\n",
122
+ "These transformations are done using the Datasets `map` function. In the case of streaming datasets, transformations will run as needed instead of pre-processing the dataset at once."
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "id": "0fdf1851",
128
+ "metadata": {},
129
+ "source": [
130
+ "This helper function is used to decode images from the bytes retrieved in `streaming` mode."
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "id": "5bbca804",
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "from PIL import Image\n",
141
+ "import io\n",
142
+ "\n",
143
+ "def get_image(byte_stream):\n",
144
+ " image = Image.open(io.BytesIO(byte_stream))\n",
145
+ " return image.convert('RGB')"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "markdown",
150
+ "id": "b435290b",
151
+ "metadata": {},
152
+ "source": [
153
+ "Image processing"
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": null,
159
+ "id": "7e73dfa3",
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
+ "def center_crop(image, max_size=256):\n",
164
+ " # Note: we allow upscaling too. We should exclude small images. \n",
165
+ " image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
166
+ " image = TF.center_crop(image, output_size=2 * [max_size])\n",
167
+ " return image\n",
168
+ "\n",
169
+ "preprocess_image = T.Compose([\n",
170
+ " get_image,\n",
171
+ " center_crop,\n",
172
+ " T.ToTensor(),\n",
173
+ " lambda t: t.permute(1, 2, 0) # Reorder, we need dimensions last\n",
174
+ "])"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "markdown",
179
+ "id": "1e3ac8de",
180
+ "metadata": {},
181
+ "source": [
182
+ "Caption preparation"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "id": "aadb4d23",
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": [
192
+ "import string\n",
193
+ "\n",
194
+ "def create_caption(title, description):\n",
195
+ " title = title.strip()\n",
196
+ " description = description.strip()\n",
197
+ " if len(title) > 0 and title[-1] not in '.!?': title += '.'\n",
198
+ " return f'{title} {description}'"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "markdown",
203
+ "id": "3c4522b9",
204
+ "metadata": {},
205
+ "source": [
206
+ "And this is the basic transformation function to use in `map`. We don't really need the `key`, but we'll keep it for reference. Since we are returning a new dictionary (as opposed to adding entries to the input), this also removes any metadata columns we don't need."
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": null,
212
+ "id": "2566ff68",
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "def prepare_item(item):\n",
217
+ " return {\n",
218
+ " 'key': item['key'],\n",
219
+ " 'caption': create_caption(item['title_clean'], item['description_clean']),\n",
220
+ " 'image': preprocess_image(item['img'])\n",
221
+ " }"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "markdown",
226
+ "id": "e519e475",
227
+ "metadata": {},
228
+ "source": [
229
+ "Unlike when using non-streaming datasets, the following operation completes immediately in streaming mode. In streaming mode, `num_proc` is not supported."
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "id": "10d7750e",
236
+ "metadata": {},
237
+ "outputs": [],
238
+ "source": [
239
+ "prepared_dataset = dataset.map(prepare_item, batched=False)"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": null,
245
+ "id": "a8595539",
246
+ "metadata": {},
247
+ "outputs": [],
248
+ "source": [
249
+ "%%time\n",
250
+ "item = next(iter(prepared_dataset))"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": null,
256
+ "id": "04a6eeb4",
257
+ "metadata": {},
258
+ "outputs": [],
259
+ "source": [
260
+ "assert(list(item.keys()) == ['key', 'caption', 'image'])"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": null,
266
+ "id": "40d3115f",
267
+ "metadata": {},
268
+ "outputs": [],
269
+ "source": [
270
+ "item['image'].shape"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": null,
276
+ "id": "dd844e1c",
277
+ "metadata": {},
278
+ "outputs": [],
279
+ "source": [
280
+ "T.ToPILImage()(item['image'].permute(2, 0, 1))"
281
+ ]
282
+ },
283
+ {
284
+ "cell_type": "markdown",
285
+ "id": "44d50a51",
286
+ "metadata": {},
287
+ "source": [
288
+ "### Torch DataLoader"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "markdown",
293
+ "id": "17a4bbc6",
294
+ "metadata": {},
295
+ "source": [
296
+ "We'll create a PyTorch DataLoader for convenience. This allows us to easily take batches of our desired size.\n",
297
+ "\n",
298
+ "We won't be using parallel processing of the DataLoader for now, as the items will be retrieved on the fly. We could attempt to do it using these recommendations: https://pytorch.org/docs/stable/data.html#multi-process-data-loading. For performance considerations, please refer to this thread: https://discuss.huggingface.co/t/allow-streaming-of-large-datasets-with-image-audio/8062/13"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "id": "e1c08b7e",
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": [
308
+ "import torch\n",
309
+ "from torch.utils.data import DataLoader"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": null,
315
+ "id": "6a296677",
316
+ "metadata": {},
317
+ "outputs": [],
318
+ "source": [
319
+ "torch_dataset = prepared_dataset.with_format(\"torch\")"
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "markdown",
324
+ "id": "29ab13bc",
325
+ "metadata": {},
326
+ "source": [
327
+ "**Note**: according to my tests, `num_workers` is not compatible with Datasets in streaming mode. Processes deadlock and there's no progress."
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": null,
333
+ "id": "e2df5e13",
334
+ "metadata": {},
335
+ "outputs": [],
336
+ "source": [
337
+ "dataloader = DataLoader(torch_dataset, batch_size=batch_size * jax.device_count())"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": null,
343
+ "id": "c15e3783",
344
+ "metadata": {},
345
+ "outputs": [],
346
+ "source": [
347
+ "batch = next(iter(dataloader))"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "code",
352
+ "execution_count": null,
353
+ "id": "71d027fe",
354
+ "metadata": {},
355
+ "outputs": [],
356
+ "source": [
357
+ "batch['image'].shape"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "markdown",
362
+ "id": "a354472b",
363
+ "metadata": {},
364
+ "source": [
365
+ "## VQGAN-JAX model"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": null,
371
+ "id": "2fcf01d7",
372
+ "metadata": {},
373
+ "outputs": [],
374
+ "source": [
375
+ "from vqgan_jax.modeling_flax_vqgan import VQModel"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "markdown",
380
+ "id": "9daa636d",
381
+ "metadata": {},
382
+ "source": [
383
+ "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
384
+ ]
385
+ },
386
+ {
387
+ "cell_type": "code",
388
+ "execution_count": null,
389
+ "id": "47a8b818",
390
+ "metadata": {
391
+ "scrolled": true
392
+ },
393
+ "outputs": [],
394
+ "source": [
395
+ "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
396
+ ]
397
+ },
398
+ {
399
+ "cell_type": "markdown",
400
+ "id": "62ad01c3",
401
+ "metadata": {},
402
+ "source": [
403
+ "## Encoding"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "markdown",
408
+ "id": "20357f74",
409
+ "metadata": {},
410
+ "source": [
411
+ "Encoding is really simple using `shard` to automatically distribute \"superbatches\" across devices, and `pmap`. This is all it takes to create our encoding function, which will be jitted on first use."
412
+ ]
413
+ },
414
+ {
415
+ "cell_type": "code",
416
+ "execution_count": null,
417
+ "id": "6686b004",
418
+ "metadata": {},
419
+ "outputs": [],
420
+ "source": [
421
+ "from flax.training.common_utils import shard\n",
422
+ "from functools import partial"
423
+ ]
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "execution_count": null,
428
+ "id": "322a4619",
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": [
432
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
433
+ "def encode(batch):\n",
434
+ " # Not sure if we should `replicate` params, does not seem to have any effect\n",
435
+ " _, indices = model.encode(batch)\n",
436
+ " return indices"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "markdown",
441
+ "id": "14375a41",
442
+ "metadata": {},
443
+ "source": [
444
+ "### Encoding loop"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "code",
449
+ "execution_count": null,
450
+ "id": "ff6c10d4",
451
+ "metadata": {},
452
+ "outputs": [],
453
+ "source": [
454
+ "import os\n",
455
+ "import pandas as pd\n",
456
+ "\n",
457
+ "def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n",
458
+ " output_dir.mkdir(parents=True, exist_ok=True)\n",
459
+ " \n",
460
+ " # Saving strategy:\n",
461
+ " # - Create a new file every so often to prevent excessive file seeking.\n",
462
+ " # - Save each batch after processing.\n",
463
+ " # - Keep the file open until we are done with it.\n",
464
+ " file = None \n",
465
+ " for n, batch in enumerate(tqdm(iter(dataloader))):\n",
466
+ " if (n % save_every == 0):\n",
467
+ " if file is not None:\n",
468
+ " file.close()\n",
469
+ " split_num = n // save_every\n",
470
+ " file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')\n",
471
+ "\n",
472
+ " images = batch[\"image\"].numpy()\n",
473
+ " images = shard(images.squeeze())\n",
474
+ " encoded = encode(images)\n",
475
+ " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
476
+ "\n",
477
+ " keys = batch[\"key\"]\n",
478
+ " captions = batch[\"caption\"]\n",
479
+ "\n",
480
+ " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
481
+ " batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded_as_string})\n",
482
+ " batch_df.to_json(file, orient='records', lines=True)"
483
+ ]
484
+ },
485
+ {
486
+ "cell_type": "markdown",
487
+ "id": "09ff75a3",
488
+ "metadata": {},
489
+ "source": [
490
+ "Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024."
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "execution_count": null,
496
+ "id": "96222bb4",
497
+ "metadata": {},
498
+ "outputs": [],
499
+ "source": [
500
+ "save_every = 318"
501
+ ]
502
+ },
503
+ {
504
+ "cell_type": "code",
505
+ "execution_count": null,
506
+ "id": "7704863d",
507
+ "metadata": {},
508
+ "outputs": [
509
+ {
510
+ "name": "stderr",
511
+ "output_type": "stream",
512
+ "text": [
513
+ "28it [01:17, 1.60s/it]"
514
+ ]
515
+ }
516
+ ],
517
+ "source": [
518
+ "encode_captioned_dataset(dataloader, yfcc100m_output, save_every=save_every)"
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "markdown",
523
+ "id": "e266a70a",
524
+ "metadata": {},
525
+ "source": [
526
+ "This is ~10-15x slower than local encoding from an SSD. For performance considerations, see the discussion at https://discuss.huggingface.co/t/allow-streaming-of-large-datasets-with-image-audio/8062/13."
527
+ ]
528
+ },
529
+ {
530
+ "cell_type": "markdown",
531
+ "id": "8953dd84",
532
+ "metadata": {},
533
+ "source": [
534
+ "----"
535
+ ]
536
+ }
537
+ ],
538
+ "metadata": {
539
+ "interpreter": {
540
+ "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
541
+ },
542
+ "kernelspec": {
543
+ "display_name": "Python 3 (ipykernel)",
544
+ "language": "python",
545
+ "name": "python3"
546
+ },
547
+ "language_info": {
548
+ "codemirror_mode": {
549
+ "name": "ipython",
550
+ "version": 3
551
+ },
552
+ "file_extension": ".py",
553
+ "mimetype": "text/x-python",
554
+ "name": "python",
555
+ "nbconvert_exporter": "python",
556
+ "pygments_lexer": "ipython3",
557
+ "version": "3.8.10"
558
+ }
559
+ },
560
+ "nbformat": 4,
561
+ "nbformat_minor": 5
562
+ }
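The encoding loop above writes one JSON-lines split per `save_every` batches with `key`, `caption` and `encoding` columns. A sketch of how a split might be read back (the path is illustrative, and parsing the stringified encodings back into arrays is an assumption about the round-trip):

```python
# Sketch: read back one split written by encode_captioned_dataset above.
# The path is illustrative; encodings were serialized with np.array2string,
# so they are parsed back into integer arrays here (assumed round-trip).
import json
from pathlib import Path
import numpy as np
import pandas as pd

split = Path.home() / "data" / "YFCC100M_OpenAI_subset" / "encoded" / "split_00000.jsonl"
df = pd.read_json(split, orient="records", lines=True)

tokens = np.array(json.loads(df.loc[0, "encoding"]))  # e.g. "[12,345,...]" -> array
print(df.columns.tolist(), tokens.shape)
```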
dev/encoding/vqgan-jax-encoding-webdataset.ipynb ADDED
@@ -0,0 +1,461 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d0b72877",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VQGAN JAX Encoding for `webdataset`"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "ba7b31e6",
14
+ "metadata": {},
15
+ "source": [
16
+ "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
17
+ "\n",
18
+ "This example uses a small subset of YFCC100M we created for testing, but it should be easy to adapt to any other image/caption dataset in the `webdataset` format."
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "3b59489e",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import numpy as np\n",
29
+ "from tqdm import tqdm\n",
30
+ "\n",
31
+ "import torch\n",
32
+ "import torchvision.transforms as T\n",
33
+ "import torchvision.transforms.functional as TF\n",
34
+ "from torchvision.transforms import InterpolationMode\n",
35
+ "import math\n",
36
+ "\n",
37
+ "import webdataset as wds\n",
38
+ "\n",
39
+ "import jax\n",
40
+ "from jax import pmap"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "id": "c7c4c1e6",
46
+ "metadata": {},
47
+ "source": [
48
+ "## Dataset and Parameters"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "markdown",
53
+ "id": "9822850f",
54
+ "metadata": {},
55
+ "source": [
56
+ "The following is the list of shards we'll process. We hardcode the length of data so that we can see nice progress bars using `tqdm`."
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "id": "1265dbfe",
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "shards = 'https://huggingface.co/datasets/dalle-mini/YFCC100M_OpenAI_subset/resolve/main/data/shard-{0000..0008}.tar'\n",
67
+ "length = 8320"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "markdown",
72
+ "id": "7e38fa14",
73
+ "metadata": {},
74
+ "source": [
75
+ "If we are extra cautious or our server is unreliable, we can enable retries by providing a custom `curl` retrieval command:"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "id": "4c8c5960",
82
+ "metadata": {},
83
+ "outputs": [],
84
+ "source": [
85
+ "# Enable curl retries to try to work around temporary network / server errors.\n",
86
+ "# This shouldn't be necessary when using reliable servers.\n",
87
+ "# shards = f'pipe:curl -s --retry 5 --retry-delay 5 -L {shards} || true'"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "id": "13c6631b",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "from pathlib import Path\n",
98
+ "\n",
99
+ "# Output directory for encoded files\n",
100
+ "encoded_output = Path.home()/'data'/'wds'/'encoded'\n",
101
+ "\n",
102
+ "batch_size = 128 # Per device\n",
103
+ "num_workers = 8 # For parallel processing"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "id": "3435fb85",
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "bs = batch_size * jax.device_count() # You can use a smaller size while testing\n",
114
+ "batches = math.ceil(length / bs)"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "markdown",
119
+ "id": "88598e4b",
120
+ "metadata": {},
121
+ "source": [
122
+ "Image processing"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "id": "669b35df",
129
+ "metadata": {},
130
+ "outputs": [],
131
+ "source": [
132
+ "def center_crop(image, max_size=256):\n",
133
+ " # Note: we allow upscaling too. We should exclude small images. \n",
134
+ " image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
135
+ " image = TF.center_crop(image, output_size=2 * [max_size])\n",
136
+ " return image\n",
137
+ "\n",
138
+ "preprocess_image = T.Compose([\n",
139
+ " center_crop,\n",
140
+ " T.ToTensor(),\n",
141
+ " lambda t: t.permute(1, 2, 0) # Reorder, we need dimensions last\n",
142
+ "])"
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "markdown",
147
+ "id": "a185e90c",
148
+ "metadata": {},
149
+ "source": [
150
+ "Caption preparation.\n",
151
+ "\n",
152
+ "Note that we receive the contents of the `json` structure, which will be replaced by the string we return.\n",
153
+ "If we want to keep other fields inside `json`, we can add `caption` as a new field."
154
+ ]
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "execution_count": null,
159
+ "id": "423ee10e",
160
+ "metadata": {},
161
+ "outputs": [],
162
+ "source": [
163
+ "def create_caption(item):\n",
164
+ " title = item['title_clean'].strip()\n",
165
+ " description = item['description_clean'].strip()\n",
166
+ " if len(title) > 0 and title[-1] not in '.!?': title += '.'\n",
167
+ " return f'{title} {description}'"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "markdown",
172
+ "id": "8d3a95db",
173
+ "metadata": {},
174
+ "source": [
175
+ "When an error occurs (a download is disconnected, an image cannot be decoded, etc) the process stops with an exception. We can use one of the exception handlers provided by the `webdataset` library, such as `wds.warn_and_continue` or `wds.ignore_and_continue` to ignore the offending entry and keep iterating.\n",
176
+ "\n",
177
+ "**IMPORTANT WARNING:** Do not use error handlers to ignore exceptions until you have tested that your processing pipeline works fine. Otherwise, the process will continue trying to find a valid entry, and it will consume your whole dataset without doing any work.\n",
178
+ "\n",
179
+ "We can also create our custom exception handler as demonstrated here:"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "id": "369d9719",
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": [
189
+ "# UNUSED - Log exceptions to a file\n",
190
+ "def ignore_and_log(exn):\n",
191
+ " with open('errors.txt', 'a') as f:\n",
192
+ " f.write(f'{repr(exn)}\\n')\n",
193
+ " return True"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": null,
199
+ "id": "27de1414",
200
+ "metadata": {},
201
+ "outputs": [],
202
+ "source": [
203
+ "# Or simply use `wds.ignore_and_continue`\n",
204
+ "exception_handler = wds.warn_and_continue"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": null,
210
+ "id": "5149b6d5",
211
+ "metadata": {},
212
+ "outputs": [],
213
+ "source": [
214
+ "dataset = wds.WebDataset(shards,\n",
215
+ " length=batches, # Hint so `len` is implemented\n",
216
+ " shardshuffle=False, # Keep same order for encoded files for easier bookkeeping. Set to `True` for training.\n",
217
+ " handler=exception_handler, # Ignore read errors instead of failing.\n",
218
+ ")\n",
219
+ "\n",
220
+ "dataset = (dataset \n",
221
+ " .decode('pil') # decode image with PIL\n",
222
+ "# .map_dict(jpg=preprocess_image, json=create_caption, handler=exception_handler) # Process fields with functions defined above\n",
223
+ " .map_dict(jpg=preprocess_image, json=create_caption) # Process fields with functions defined above\n",
224
+ " .to_tuple('__key__', 'jpg', 'json') # filter to keep only key (for reference), image, caption.\n",
225
+ " .batched(bs)) # better to batch in the dataset (but we could also do it in the dataloader) - this arg does not affect speed and we could remove it"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "id": "8cac98cb",
232
+ "metadata": {
233
+ "scrolled": true
234
+ },
235
+ "outputs": [],
236
+ "source": [
237
+ "%%time\n",
238
+ "keys, images, captions = next(iter(dataset))"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": null,
244
+ "id": "cd268fbf",
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": [
248
+ "images.shape"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": null,
254
+ "id": "c24693c0",
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
+ "T.ToPILImage()(images[0].permute(2, 0, 1))"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "markdown",
263
+ "id": "44d50a51",
264
+ "metadata": {},
265
+ "source": [
266
+ "### Torch DataLoader"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": null,
272
+ "id": "e2df5e13",
273
+ "metadata": {},
274
+ "outputs": [],
275
+ "source": [
276
+ "dl = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=num_workers)"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "markdown",
281
+ "id": "a354472b",
282
+ "metadata": {},
283
+ "source": [
284
+ "## VQGAN-JAX model"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": null,
290
+ "id": "2fcf01d7",
291
+ "metadata": {},
292
+ "outputs": [],
293
+ "source": [
294
+ "from vqgan_jax.modeling_flax_vqgan import VQModel"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "markdown",
299
+ "id": "9daa636d",
300
+ "metadata": {},
301
+ "source": [
302
+ "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "execution_count": null,
308
+ "id": "47a8b818",
309
+ "metadata": {
310
+ "scrolled": true
311
+ },
312
+ "outputs": [],
313
+ "source": [
314
+ "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "markdown",
319
+ "id": "62ad01c3",
320
+ "metadata": {},
321
+ "source": [
322
+ "## Encoding"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "markdown",
327
+ "id": "20357f74",
328
+ "metadata": {},
329
+ "source": [
330
+ "Encoding is really simple using `shard` to automatically distribute \"superbatches\" across devices, and `pmap`. This is all it takes to create our encoding function, which will be jitted on first use."
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": null,
336
+ "id": "6686b004",
337
+ "metadata": {},
338
+ "outputs": [],
339
+ "source": [
340
+ "from flax.training.common_utils import shard\n",
341
+ "from functools import partial"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": null,
347
+ "id": "322a4619",
348
+ "metadata": {},
349
+ "outputs": [],
350
+ "source": [
351
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
352
+ "def encode(batch):\n",
353
+ " # Not sure if we should `replicate` params, does not seem to have any effect\n",
354
+ " _, indices = model.encode(batch)\n",
355
+ " return indices"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "markdown",
360
+ "id": "14375a41",
361
+ "metadata": {},
362
+ "source": [
363
+ "### Encoding loop"
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "code",
368
+ "execution_count": null,
369
+ "id": "ff6c10d4",
370
+ "metadata": {},
371
+ "outputs": [],
372
+ "source": [
373
+ "import os\n",
374
+ "import pandas as pd\n",
375
+ "\n",
376
+ "def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n",
377
+ " output_dir.mkdir(parents=True, exist_ok=True)\n",
378
+ "\n",
379
+ " # Saving strategy:\n",
380
+ " # - Create a new file every so often to prevent excessive file seeking.\n",
381
+ " # - Save each batch after processing.\n",
382
+ " # - Keep the file open until we are done with it.\n",
383
+ " file = None \n",
384
+ " for n, (keys, images, captions) in enumerate(tqdm(dataloader)):\n",
385
+ " if (n % save_every == 0):\n",
386
+ " if file is not None:\n",
387
+ " file.close()\n",
388
+ " split_num = n // save_every\n",
389
+ " file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')\n",
390
+ "\n",
391
+ " images = shard(images.numpy().squeeze())\n",
392
+ " encoded = encode(images)\n",
393
+ " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
394
+ "\n",
395
+ " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
396
+ " batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded_as_string})\n",
397
+ " batch_df.to_json(file, orient='records', lines=True)"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "markdown",
402
+ "id": "09ff75a3",
403
+ "metadata": {},
404
+ "source": [
405
+ "Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024."
406
+ ]
407
+ },
408
+ {
409
+ "cell_type": "code",
410
+ "execution_count": null,
411
+ "id": "96222bb4",
412
+ "metadata": {},
413
+ "outputs": [],
414
+ "source": [
415
+ "save_every = 318"
416
+ ]
417
+ },
418
+ {
419
+ "cell_type": "code",
420
+ "execution_count": null,
421
+ "id": "7704863d",
422
+ "metadata": {},
423
+ "outputs": [],
424
+ "source": [
425
+ "encode_captioned_dataset(dl, encoded_output, save_every=save_every)"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "markdown",
430
+ "id": "8953dd84",
431
+ "metadata": {},
432
+ "source": [
433
+ "----"
434
+ ]
435
+ }
436
+ ],
437
+ "metadata": {
438
+ "interpreter": {
439
+ "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
440
+ },
441
+ "kernelspec": {
442
+ "display_name": "Python 3 (ipykernel)",
443
+ "language": "python",
444
+ "name": "python3"
445
+ },
446
+ "language_info": {
447
+ "codemirror_mode": {
448
+ "name": "ipython",
449
+ "version": 3
450
+ },
451
+ "file_extension": ".py",
452
+ "mimetype": "text/x-python",
453
+ "name": "python",
454
+ "nbconvert_exporter": "python",
455
+ "pygments_lexer": "ipython3",
456
+ "version": "3.8.10"
457
+ }
458
+ },
459
+ "nbformat": 4,
460
+ "nbformat_minor": 5
461
+ }
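Both encoding notebooks rely on `flax.training.common_utils.shard` plus `pmap` to spread a "superbatch" across devices. A small sketch of the reshape `shard` performs (shapes assume the per-device batch size of 128 used above):

```python
# Illustration of the superbatch -> per-device reshape done by `shard` in the
# notebooks above: the leading axis becomes the device axis that pmap maps over.
import numpy as np
import jax
from flax.training.common_utils import shard

per_device_batch = 128
superbatch = np.zeros((per_device_batch * jax.local_device_count(), 256, 256, 3), dtype=np.float32)

sharded = shard(superbatch)
print(sharded.shape)  # (num_devices, per_device_batch, 256, 256, 3)
```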
dev/inference/dalle_mini DELETED
@@ -1 +0,0 @@
1
- ../../dalle_mini
dev/inference/inference_pipeline.ipynb CHANGED
@@ -6,7 +6,7 @@
6
  "name": "DALL·E mini - Inference pipeline.ipynb",
7
  "provenance": [],
8
  "collapsed_sections": [],
9
- "authorship_tag": "ABX9TyOmaisFwTAYRR7mJmMVxzdA",
10
  "include_colab_link": true
11
  },
12
  "kernelspec": {
@@ -22,6 +22,7 @@
22
  "49304912717a4995ae45d04a59d1f50e": {
23
  "model_module": "@jupyter-widgets/controls",
24
  "model_name": "HBoxModel",
 
25
  "state": {
26
  "_view_name": "HBoxView",
27
  "_dom_classes": [],
@@ -42,6 +43,7 @@
42
  "5fd9f97986024e8db560a6737ade9e2e": {
43
  "model_module": "@jupyter-widgets/base",
44
  "model_name": "LayoutModel",
 
45
  "state": {
46
  "_view_name": "LayoutView",
47
  "grid_template_rows": null,
@@ -93,6 +95,7 @@
93
  "caced43e3a4c493b98fb07cb41db045c": {
94
  "model_module": "@jupyter-widgets/controls",
95
  "model_name": "FloatProgressModel",
 
96
  "state": {
97
  "_view_name": "ProgressView",
98
  "style": "IPY_MODEL_40c54b9454d346aabd197f2bcf189467",
@@ -116,6 +119,7 @@
116
  "0acc161f2e9948b68b3fc4e57ef333c9": {
117
  "model_module": "@jupyter-widgets/controls",
118
  "model_name": "HTMLModel",
 
119
  "state": {
120
  "_view_name": "HTMLView",
121
  "style": "IPY_MODEL_7e7c488f57fc4acb8d261e2db81d61f0",
@@ -136,6 +140,7 @@
136
  "40c54b9454d346aabd197f2bcf189467": {
137
  "model_module": "@jupyter-widgets/controls",
138
  "model_name": "ProgressStyleModel",
 
139
  "state": {
140
  "_view_name": "StyleView",
141
  "_model_name": "ProgressStyleModel",
@@ -151,6 +156,7 @@
151
  "8b25334a48244a14aa9ba0176887e655": {
152
  "model_module": "@jupyter-widgets/base",
153
  "model_name": "LayoutModel",
 
154
  "state": {
155
  "_view_name": "LayoutView",
156
  "grid_template_rows": null,
@@ -202,6 +208,7 @@
202
  "7e7c488f57fc4acb8d261e2db81d61f0": {
203
  "model_module": "@jupyter-widgets/controls",
204
  "model_name": "DescriptionStyleModel",
 
205
  "state": {
206
  "_view_name": "StyleView",
207
  "_model_name": "DescriptionStyleModel",
@@ -216,6 +223,7 @@
216
  "72c401062a5348b1a366dffb5a403568": {
217
  "model_module": "@jupyter-widgets/base",
218
  "model_name": "LayoutModel",
 
219
  "state": {
220
  "_view_name": "LayoutView",
221
  "grid_template_rows": null,
@@ -267,6 +275,7 @@
267
  "022c124dfff348f285335732781b0887": {
268
  "model_module": "@jupyter-widgets/controls",
269
  "model_name": "HBoxModel",
 
270
  "state": {
271
  "_view_name": "HBoxView",
272
  "_dom_classes": [],
@@ -287,6 +296,7 @@
287
  "a44e47e9d26c4deb81a5a11a9db92a9f": {
288
  "model_module": "@jupyter-widgets/base",
289
  "model_name": "LayoutModel",
 
290
  "state": {
291
  "_view_name": "LayoutView",
292
  "grid_template_rows": null,
@@ -338,6 +348,7 @@
338
  "cd9c7016caae47c1b41fb2608c78b0bf": {
339
  "model_module": "@jupyter-widgets/controls",
340
  "model_name": "FloatProgressModel",
 
341
  "state": {
342
  "_view_name": "ProgressView",
343
  "style": "IPY_MODEL_c22f207311cf4fb69bd9328eabfd4ebb",
@@ -361,6 +372,7 @@
361
  "36ff1d0fea4b47e2ae35aa6bfae6a5e8": {
362
  "model_module": "@jupyter-widgets/controls",
363
  "model_name": "HTMLModel",
 
364
  "state": {
365
  "_view_name": "HTMLView",
366
  "style": "IPY_MODEL_037563a7eadd4ac5abb7249a2914d346",
@@ -381,6 +393,7 @@
381
  "c22f207311cf4fb69bd9328eabfd4ebb": {
382
  "model_module": "@jupyter-widgets/controls",
383
  "model_name": "ProgressStyleModel",
 
384
  "state": {
385
  "_view_name": "StyleView",
386
  "_model_name": "ProgressStyleModel",
@@ -396,6 +409,7 @@
396
  "5a38c6d83a264bedbf7efe6e97eba953": {
397
  "model_module": "@jupyter-widgets/base",
398
  "model_name": "LayoutModel",
 
399
  "state": {
400
  "_view_name": "LayoutView",
401
  "grid_template_rows": null,
@@ -447,6 +461,7 @@
447
  "037563a7eadd4ac5abb7249a2914d346": {
448
  "model_module": "@jupyter-widgets/controls",
449
  "model_name": "DescriptionStyleModel",
 
450
  "state": {
451
  "_view_name": "StyleView",
452
  "_model_name": "DescriptionStyleModel",
@@ -461,6 +476,7 @@
461
  "3975e7ed0b704990b1fa05909a9bb9b6": {
462
  "model_module": "@jupyter-widgets/base",
463
  "model_name": "LayoutModel",
 
464
  "state": {
465
  "_view_name": "LayoutView",
466
  "grid_template_rows": null,
@@ -512,6 +528,7 @@
512
  "f9f1fdc3819a4142b85304cd3c6358a2": {
513
  "model_module": "@jupyter-widgets/controls",
514
  "model_name": "HBoxModel",
 
515
  "state": {
516
  "_view_name": "HBoxView",
517
  "_dom_classes": [],
@@ -532,6 +549,7 @@
532
  "ea9ed54e7c9d4ead8b3e1ff4cb27fa61": {
533
  "model_module": "@jupyter-widgets/base",
534
  "model_name": "LayoutModel",
 
535
  "state": {
536
  "_view_name": "LayoutView",
537
  "grid_template_rows": null,
@@ -583,6 +601,7 @@
583
  "29d42e94b3b34c86a117b623da68faed": {
584
  "model_module": "@jupyter-widgets/controls",
585
  "model_name": "FloatProgressModel",
 
586
  "state": {
587
  "_view_name": "ProgressView",
588
  "style": "IPY_MODEL_8ce4d20d004a4382afa0abdd3b1f7191",
@@ -606,6 +625,7 @@
606
  "8b73de7dbdfe40dbbb39fb593520b984": {
607
  "model_module": "@jupyter-widgets/controls",
608
  "model_name": "HTMLModel",
 
609
  "state": {
610
  "_view_name": "HTMLView",
611
  "style": "IPY_MODEL_717ccef4df1f477abb51814650eb47da",
@@ -626,6 +646,7 @@
626
  "8ce4d20d004a4382afa0abdd3b1f7191": {
627
  "model_module": "@jupyter-widgets/controls",
628
  "model_name": "ProgressStyleModel",
 
629
  "state": {
630
  "_view_name": "StyleView",
631
  "_model_name": "ProgressStyleModel",
@@ -641,6 +662,7 @@
641
  "efc4812245c8459c92e6436889b4f600": {
642
  "model_module": "@jupyter-widgets/base",
643
  "model_name": "LayoutModel",
 
644
  "state": {
645
  "_view_name": "LayoutView",
646
  "grid_template_rows": null,
@@ -692,6 +714,7 @@
692
  "717ccef4df1f477abb51814650eb47da": {
693
  "model_module": "@jupyter-widgets/controls",
694
  "model_name": "DescriptionStyleModel",
 
695
  "state": {
696
  "_view_name": "StyleView",
697
  "_model_name": "DescriptionStyleModel",
@@ -706,6 +729,7 @@
706
  "7dba58f0391c485a86e34e8039ec6189": {
707
  "model_module": "@jupyter-widgets/base",
708
  "model_name": "LayoutModel",
 
709
  "state": {
710
  "_view_name": "LayoutView",
711
  "grid_template_rows": null,
@@ -804,8 +828,7 @@
804
  "source": [
805
  "!pip install -q transformers flax\n",
806
  "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git # VQGAN model in JAX\n",
807
- "!git clone https://github.com/borisdayma/dalle-mini # Model files\n",
808
- "%cd dalle-mini/"
809
  ],
810
  "execution_count": null,
811
  "outputs": []
@@ -833,7 +856,7 @@
833
  "import random\n",
834
  "from tqdm.notebook import tqdm, trange"
835
  ],
836
- "execution_count": 2,
837
  "outputs": []
838
  },
839
  {
@@ -846,7 +869,7 @@
846
  "DALLE_REPO = 'flax-community/dalle-mini'\n",
847
  "DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'"
848
  ],
849
- "execution_count": 3,
850
  "outputs": []
851
  },
852
  {
@@ -871,7 +894,7 @@
871
  "# set a prompt\n",
872
  "prompt = 'picture of a waterfall under the sunset'"
873
  ],
874
- "execution_count": 5,
875
  "outputs": []
876
  },
877
  {
@@ -888,7 +911,7 @@
888
  "tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128)\n",
889
  "tokenized_prompt"
890
  ],
891
- "execution_count": 6,
892
  "outputs": [
893
  {
894
  "output_type": "execute_result",
@@ -956,7 +979,7 @@
956
  "subkeys = jax.random.split(key, num=n_predictions)\n",
957
  "subkeys"
958
  ],
959
- "execution_count": 7,
960
  "outputs": [
961
  {
962
  "output_type": "execute_result",
@@ -1004,7 +1027,7 @@
1004
  "encoded_images = [model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=subkey) for subkey in tqdm(subkeys)]\n",
1005
  "encoded_images[0]"
1006
  ],
1007
- "execution_count": 8,
1008
  "outputs": [
1009
  {
1010
  "output_type": "display_data",
@@ -1099,7 +1122,7 @@
1099
  "encoded_images = [img.sequences[..., 1:] for img in encoded_images]\n",
1100
  "encoded_images[0]"
1101
  ],
1102
- "execution_count": 9,
1103
  "outputs": [
1104
  {
1105
  "output_type": "execute_result",
@@ -1167,7 +1190,7 @@
1167
  "source": [
1168
  "encoded_images[0].shape"
1169
  ],
1170
- "execution_count": 10,
1171
  "outputs": [
1172
  {
1173
  "output_type": "execute_result",
@@ -1204,7 +1227,7 @@
1204
  "import numpy as np\n",
1205
  "from PIL import Image"
1206
  ],
1207
- "execution_count": 11,
1208
  "outputs": []
1209
  },
1210
  {
@@ -1217,7 +1240,7 @@
1217
  "VQGAN_REPO = 'flax-community/vqgan_f16_16384'\n",
1218
  "VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'"
1219
  ],
1220
- "execution_count": 12,
1221
  "outputs": []
1222
  },
1223
  {
@@ -1233,7 +1256,7 @@
1233
  "# set up VQGAN\n",
1234
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)"
1235
  ],
1236
- "execution_count": 13,
1237
  "outputs": [
1238
  {
1239
  "output_type": "stream",
@@ -1269,7 +1292,7 @@
1269
  "decoded_images = [vqgan.decode_code(encoded_image) for encoded_image in tqdm(encoded_images)]\n",
1270
  "decoded_images[0]"
1271
  ],
1272
- "execution_count": 14,
1273
  "outputs": [
1274
  {
1275
  "output_type": "display_data",
@@ -1373,7 +1396,7 @@
1373
  "# normalize images\n",
1374
  "clipped_images = [img.squeeze().clip(0., 1.) for img in decoded_images]"
1375
  ],
1376
- "execution_count": 15,
1377
  "outputs": []
1378
  },
1379
  {
@@ -1385,7 +1408,7 @@
1385
  "# convert to image\n",
1386
  "images = [Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in clipped_images]"
1387
  ],
1388
- "execution_count": 16,
1389
  "outputs": []
1390
  },
1391
  {
@@ -1402,7 +1425,7 @@
1402
  "# display an image\n",
1403
  "images[0]"
1404
  ],
1405
- "execution_count": 17,
1406
  "outputs": [
1407
  {
1408
  "output_type": "execute_result",
@@ -1438,7 +1461,7 @@
1438
  "source": [
1439
  "from transformers import CLIPProcessor, FlaxCLIPModel"
1440
  ],
1441
- "execution_count": 18,
1442
  "outputs": []
1443
  },
1444
  {
@@ -1474,7 +1497,7 @@
1474
  "logits = clip(**inputs).logits_per_image\n",
1475
  "scores = jax.nn.softmax(logits, axis=0).squeeze() # normalize and sum all scores to 1"
1476
  ],
1477
- "execution_count": 20,
1478
  "outputs": []
1479
  },
1480
  {
@@ -1495,7 +1518,7 @@
1495
  " display(images[idx])\n",
1496
  " print()"
1497
  ],
1498
- "execution_count": 21,
1499
  "outputs": [
1500
  {
1501
  "output_type": "stream",
@@ -1690,7 +1713,7 @@
1690
  "from flax.training.common_utils import shard\n",
1691
  "from flax.jax_utils import replicate"
1692
  ],
1693
- "execution_count": 22,
1694
  "outputs": []
1695
  },
1696
  {
@@ -1706,7 +1729,7 @@
1706
  "# check we can access TPU's or GPU's\n",
1707
  "jax.devices()"
1708
  ],
1709
- "execution_count": 23,
1710
  "outputs": [
1711
  {
1712
  "output_type": "execute_result",
@@ -1744,7 +1767,7 @@
1744
  "# one set of inputs per device\n",
1745
  "prompt = ['picture of a waterfall under the sunset'] * jax.device_count()"
1746
  ],
1747
- "execution_count": 25,
1748
  "outputs": []
1749
  },
1750
  {
@@ -1757,7 +1780,7 @@
1757
  "tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
1758
  "tokenized_prompt = shard(tokenized_prompt)"
1759
  ],
1760
- "execution_count": 26,
1761
  "outputs": []
1762
  },
1763
  {
@@ -1793,7 +1816,7 @@
1793
  "def p_decode(indices, params):\n",
1794
  " return vqgan.decode_code(indices, params=params)"
1795
  ],
1796
- "execution_count": 27,
1797
  "outputs": []
1798
  },
1799
  {
@@ -1834,7 +1857,7 @@
1834
  " for img in decoded_images:\n",
1835
  " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
1836
  ],
1837
- "execution_count": 28,
1838
  "outputs": [
1839
  {
1840
  "output_type": "display_data",
@@ -1877,7 +1900,7 @@
1877
  " display(img)\n",
1878
  " print()"
1879
  ],
1880
- "execution_count": 29,
1881
  "outputs": [
1882
  {
1883
  "output_type": "display_data",
 
6
  "name": "DALL·E mini - Inference pipeline.ipynb",
7
  "provenance": [],
8
  "collapsed_sections": [],
9
+ "authorship_tag": "ABX9TyMUjEt1XMLq+6/GhSnVFsSx",
10
  "include_colab_link": true
11
  },
12
  "kernelspec": {
 
22
  "49304912717a4995ae45d04a59d1f50e": {
23
  "model_module": "@jupyter-widgets/controls",
24
  "model_name": "HBoxModel",
25
+ "model_module_version": "1.5.0",
26
  "state": {
27
  "_view_name": "HBoxView",
28
  "_dom_classes": [],
 
43
  "5fd9f97986024e8db560a6737ade9e2e": {
44
  "model_module": "@jupyter-widgets/base",
45
  "model_name": "LayoutModel",
46
+ "model_module_version": "1.2.0",
47
  "state": {
48
  "_view_name": "LayoutView",
49
  "grid_template_rows": null,
 
95
  "caced43e3a4c493b98fb07cb41db045c": {
96
  "model_module": "@jupyter-widgets/controls",
97
  "model_name": "FloatProgressModel",
98
+ "model_module_version": "1.5.0",
99
  "state": {
100
  "_view_name": "ProgressView",
101
  "style": "IPY_MODEL_40c54b9454d346aabd197f2bcf189467",
 
119
  "0acc161f2e9948b68b3fc4e57ef333c9": {
120
  "model_module": "@jupyter-widgets/controls",
121
  "model_name": "HTMLModel",
122
+ "model_module_version": "1.5.0",
123
  "state": {
124
  "_view_name": "HTMLView",
125
  "style": "IPY_MODEL_7e7c488f57fc4acb8d261e2db81d61f0",
 
140
  "40c54b9454d346aabd197f2bcf189467": {
141
  "model_module": "@jupyter-widgets/controls",
142
  "model_name": "ProgressStyleModel",
143
+ "model_module_version": "1.5.0",
144
  "state": {
145
  "_view_name": "StyleView",
146
  "_model_name": "ProgressStyleModel",
 
156
  "8b25334a48244a14aa9ba0176887e655": {
157
  "model_module": "@jupyter-widgets/base",
158
  "model_name": "LayoutModel",
159
+ "model_module_version": "1.2.0",
160
  "state": {
161
  "_view_name": "LayoutView",
162
  "grid_template_rows": null,
 
208
  "7e7c488f57fc4acb8d261e2db81d61f0": {
209
  "model_module": "@jupyter-widgets/controls",
210
  "model_name": "DescriptionStyleModel",
211
+ "model_module_version": "1.5.0",
212
  "state": {
213
  "_view_name": "StyleView",
214
  "_model_name": "DescriptionStyleModel",
 
223
  "72c401062a5348b1a366dffb5a403568": {
224
  "model_module": "@jupyter-widgets/base",
225
  "model_name": "LayoutModel",
226
+ "model_module_version": "1.2.0",
227
  "state": {
228
  "_view_name": "LayoutView",
229
  "grid_template_rows": null,
 
275
  "022c124dfff348f285335732781b0887": {
276
  "model_module": "@jupyter-widgets/controls",
277
  "model_name": "HBoxModel",
278
+ "model_module_version": "1.5.0",
279
  "state": {
280
  "_view_name": "HBoxView",
281
  "_dom_classes": [],
 
296
  "a44e47e9d26c4deb81a5a11a9db92a9f": {
297
  "model_module": "@jupyter-widgets/base",
298
  "model_name": "LayoutModel",
299
+ "model_module_version": "1.2.0",
300
  "state": {
301
  "_view_name": "LayoutView",
302
  "grid_template_rows": null,
 
348
  "cd9c7016caae47c1b41fb2608c78b0bf": {
349
  "model_module": "@jupyter-widgets/controls",
350
  "model_name": "FloatProgressModel",
351
+ "model_module_version": "1.5.0",
352
  "state": {
353
  "_view_name": "ProgressView",
354
  "style": "IPY_MODEL_c22f207311cf4fb69bd9328eabfd4ebb",
 
372
  "36ff1d0fea4b47e2ae35aa6bfae6a5e8": {
373
  "model_module": "@jupyter-widgets/controls",
374
  "model_name": "HTMLModel",
375
+ "model_module_version": "1.5.0",
376
  "state": {
377
  "_view_name": "HTMLView",
378
  "style": "IPY_MODEL_037563a7eadd4ac5abb7249a2914d346",
 
393
  "c22f207311cf4fb69bd9328eabfd4ebb": {
394
  "model_module": "@jupyter-widgets/controls",
395
  "model_name": "ProgressStyleModel",
396
+ "model_module_version": "1.5.0",
397
  "state": {
398
  "_view_name": "StyleView",
399
  "_model_name": "ProgressStyleModel",
 
409
  "5a38c6d83a264bedbf7efe6e97eba953": {
410
  "model_module": "@jupyter-widgets/base",
411
  "model_name": "LayoutModel",
412
+ "model_module_version": "1.2.0",
413
  "state": {
414
  "_view_name": "LayoutView",
415
  "grid_template_rows": null,
 
461
  "037563a7eadd4ac5abb7249a2914d346": {
462
  "model_module": "@jupyter-widgets/controls",
463
  "model_name": "DescriptionStyleModel",
464
+ "model_module_version": "1.5.0",
465
  "state": {
466
  "_view_name": "StyleView",
467
  "_model_name": "DescriptionStyleModel",
 
476
  "3975e7ed0b704990b1fa05909a9bb9b6": {
477
  "model_module": "@jupyter-widgets/base",
478
  "model_name": "LayoutModel",
479
+ "model_module_version": "1.2.0",
480
  "state": {
481
  "_view_name": "LayoutView",
482
  "grid_template_rows": null,
 
528
  "f9f1fdc3819a4142b85304cd3c6358a2": {
529
  "model_module": "@jupyter-widgets/controls",
530
  "model_name": "HBoxModel",
531
+ "model_module_version": "1.5.0",
532
  "state": {
533
  "_view_name": "HBoxView",
534
  "_dom_classes": [],
 
549
  "ea9ed54e7c9d4ead8b3e1ff4cb27fa61": {
550
  "model_module": "@jupyter-widgets/base",
551
  "model_name": "LayoutModel",
552
+ "model_module_version": "1.2.0",
553
  "state": {
554
  "_view_name": "LayoutView",
555
  "grid_template_rows": null,
 
601
  "29d42e94b3b34c86a117b623da68faed": {
602
  "model_module": "@jupyter-widgets/controls",
603
  "model_name": "FloatProgressModel",
604
+ "model_module_version": "1.5.0",
605
  "state": {
606
  "_view_name": "ProgressView",
607
  "style": "IPY_MODEL_8ce4d20d004a4382afa0abdd3b1f7191",
 
625
  "8b73de7dbdfe40dbbb39fb593520b984": {
626
  "model_module": "@jupyter-widgets/controls",
627
  "model_name": "HTMLModel",
628
+ "model_module_version": "1.5.0",
629
  "state": {
630
  "_view_name": "HTMLView",
631
  "style": "IPY_MODEL_717ccef4df1f477abb51814650eb47da",
 
646
  "8ce4d20d004a4382afa0abdd3b1f7191": {
647
  "model_module": "@jupyter-widgets/controls",
648
  "model_name": "ProgressStyleModel",
649
+ "model_module_version": "1.5.0",
650
  "state": {
651
  "_view_name": "StyleView",
652
  "_model_name": "ProgressStyleModel",
 
662
  "efc4812245c8459c92e6436889b4f600": {
663
  "model_module": "@jupyter-widgets/base",
664
  "model_name": "LayoutModel",
665
+ "model_module_version": "1.2.0",
666
  "state": {
667
  "_view_name": "LayoutView",
668
  "grid_template_rows": null,
 
714
  "717ccef4df1f477abb51814650eb47da": {
715
  "model_module": "@jupyter-widgets/controls",
716
  "model_name": "DescriptionStyleModel",
717
+ "model_module_version": "1.5.0",
718
  "state": {
719
  "_view_name": "StyleView",
720
  "_model_name": "DescriptionStyleModel",
 
729
  "7dba58f0391c485a86e34e8039ec6189": {
730
  "model_module": "@jupyter-widgets/base",
731
  "model_name": "LayoutModel",
732
+ "model_module_version": "1.2.0",
733
  "state": {
734
  "_view_name": "LayoutView",
735
  "grid_template_rows": null,
 
828
  "source": [
829
  "!pip install -q transformers flax\n",
830
  "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git # VQGAN model in JAX\n",
831
+ "!pip install -q git+https://github.com/borisdayma/dalle-mini.git # Model files"
 
832
  ],
833
  "execution_count": null,
834
  "outputs": []
 
856
  "import random\n",
857
  "from tqdm.notebook import tqdm, trange"
858
  ],
859
+ "execution_count": null,
860
  "outputs": []
861
  },
862
  {
 
869
  "DALLE_REPO = 'flax-community/dalle-mini'\n",
870
  "DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'"
871
  ],
872
+ "execution_count": null,
873
  "outputs": []
874
  },
875
  {
 
894
  "# set a prompt\n",
895
  "prompt = 'picture of a waterfall under the sunset'"
896
  ],
897
+ "execution_count": null,
898
  "outputs": []
899
  },
900
  {
 
911
  "tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128)\n",
912
  "tokenized_prompt"
913
  ],
914
+ "execution_count": null,
915
  "outputs": [
916
  {
917
  "output_type": "execute_result",
 
979
  "subkeys = jax.random.split(key, num=n_predictions)\n",
980
  "subkeys"
981
  ],
982
+ "execution_count": null,
983
  "outputs": [
984
  {
985
  "output_type": "execute_result",
 
1027
  "encoded_images = [model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=subkey) for subkey in tqdm(subkeys)]\n",
1028
  "encoded_images[0]"
1029
  ],
1030
+ "execution_count": null,
1031
  "outputs": [
1032
  {
1033
  "output_type": "display_data",
 
1122
  "encoded_images = [img.sequences[..., 1:] for img in encoded_images]\n",
1123
  "encoded_images[0]"
1124
  ],
1125
+ "execution_count": null,
1126
  "outputs": [
1127
  {
1128
  "output_type": "execute_result",
 
1190
  "source": [
1191
  "encoded_images[0].shape"
1192
  ],
1193
+ "execution_count": null,
1194
  "outputs": [
1195
  {
1196
  "output_type": "execute_result",
 
1227
  "import numpy as np\n",
1228
  "from PIL import Image"
1229
  ],
1230
+ "execution_count": null,
1231
  "outputs": []
1232
  },
1233
  {
 
1240
  "VQGAN_REPO = 'flax-community/vqgan_f16_16384'\n",
1241
  "VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'"
1242
  ],
1243
+ "execution_count": null,
1244
  "outputs": []
1245
  },
1246
  {
 
1256
  "# set up VQGAN\n",
1257
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)"
1258
  ],
1259
+ "execution_count": null,
1260
  "outputs": [
1261
  {
1262
  "output_type": "stream",
 
1292
  "decoded_images = [vqgan.decode_code(encoded_image) for encoded_image in tqdm(encoded_images)]\n",
1293
  "decoded_images[0]"
1294
  ],
1295
+ "execution_count": null,
1296
  "outputs": [
1297
  {
1298
  "output_type": "display_data",
 
1396
  "# normalize images\n",
1397
  "clipped_images = [img.squeeze().clip(0., 1.) for img in decoded_images]"
1398
  ],
1399
+ "execution_count": null,
1400
  "outputs": []
1401
  },
1402
  {
 
1408
  "# convert to image\n",
1409
  "images = [Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in clipped_images]"
1410
  ],
1411
+ "execution_count": null,
1412
  "outputs": []
1413
  },
1414
  {
 
1425
  "# display an image\n",
1426
  "images[0]"
1427
  ],
1428
+ "execution_count": null,
1429
  "outputs": [
1430
  {
1431
  "output_type": "execute_result",
 
1461
  "source": [
1462
  "from transformers import CLIPProcessor, FlaxCLIPModel"
1463
  ],
1464
+ "execution_count": null,
1465
  "outputs": []
1466
  },
1467
  {
 
1497
  "logits = clip(**inputs).logits_per_image\n",
1498
  "scores = jax.nn.softmax(logits, axis=0).squeeze() # normalize and sum all scores to 1"
1499
  ],
1500
+ "execution_count": null,
1501
  "outputs": []
1502
  },
1503
  {
 
1518
  " display(images[idx])\n",
1519
  " print()"
1520
  ],
1521
+ "execution_count": null,
1522
  "outputs": [
1523
  {
1524
  "output_type": "stream",
 
1713
  "from flax.training.common_utils import shard\n",
1714
  "from flax.jax_utils import replicate"
1715
  ],
1716
+ "execution_count": null,
1717
  "outputs": []
1718
  },
1719
  {
 
1729
  "# check we can access TPU's or GPU's\n",
1730
  "jax.devices()"
1731
  ],
1732
+ "execution_count": null,
1733
  "outputs": [
1734
  {
1735
  "output_type": "execute_result",
 
1767
  "# one set of inputs per device\n",
1768
  "prompt = ['picture of a waterfall under the sunset'] * jax.device_count()"
1769
  ],
1770
+ "execution_count": null,
1771
  "outputs": []
1772
  },
1773
  {
 
1780
  "tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
1781
  "tokenized_prompt = shard(tokenized_prompt)"
1782
  ],
1783
+ "execution_count": null,
1784
  "outputs": []
1785
  },
1786
  {
 
1816
  "def p_decode(indices, params):\n",
1817
  " return vqgan.decode_code(indices, params=params)"
1818
  ],
1819
+ "execution_count": null,
1820
  "outputs": []
1821
  },
1822
  {
 
1857
  " for img in decoded_images:\n",
1858
  " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
1859
  ],
1860
+ "execution_count": null,
1861
  "outputs": [
1862
  {
1863
  "output_type": "display_data",
 
1900
  " display(img)\n",
1901
  " print()"
1902
  ],
1903
+ "execution_count": null,
1904
  "outputs": [
1905
  {
1906
  "output_type": "display_data",
dev/requirements.txt CHANGED
@@ -1,10 +1,8 @@
1
- # Note: install with the following command:
2
- # pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
3
- # Otherwise it won't find the appropriate libtpu_nightly
4
  requests
 
5
  jax[tpu]>=0.2.16
6
- -e git+https://github.com/huggingface/transformers.git@master#egg=transformers
7
- -e git+https://github.com/huggingface/datasets.git@master#egg=datasets
8
  flax
9
  jupyter
10
  wandb
 
 
 
 
1
  requests
2
+ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
3
  jax[tpu]>=0.2.16
4
+ transformers
5
+ datasets
6
  flax
7
  jupyter
8
  wandb
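With this change, the extra wheel index that was previously documented only in comments lives inside the requirements file itself. Since pip honors `-f` / `--find-links` lines in requirements files, a plain `pip install -r dev/requirements.txt` should be enough to resolve `jax[tpu]` against the libtpu release index without any additional command-line flag; the editable installs of `transformers` and `datasets` from their master branches are likewise replaced by the released packages.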
requirements.txt DELETED
@@ -1,2 +0,0 @@
1
- # Requirements for huggingface spaces
2
- streamlit>=0.84.2
 
 
 
setup.cfg ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [metadata]
2
+ name = dalle_mini
3
+ version = attr: dalle_mini.__version__
4
+ description = DALL·E mini - Generate images from a text prompt
5
+ long_description = file: README.md
6
+ long_description_content_type = text/markdown
7
+ url = https://github.com/borisdayma/dalle-mini
8
+ project_urls =
9
+ Bug Tracker = https://github.com/borisdayma/dalle-mini/issues
10
+
11
+ [options]
12
+ packages = find:
13
+ install_requires =
14
+ transformers
15
+ jax
16
+ flax
setup.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from setuptools import setup
2
+
3
+ if __name__ == "__main__":
4
+ setup()
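With `setup.cfg` and the minimal `setup.py` shim above, the project is installable as a regular package, which is what the updated inference notebook relies on (`pip install -q git+https://github.com/borisdayma/dalle-mini.git`) and presumably why the `dev/inference/dalle_mini` symlink is no longer needed. A quick sanity check after installation, assuming only what `setup.cfg` declares (the package version is read from `dalle_mini.__version__`):

```python
# Sanity check after `pip install git+https://github.com/borisdayma/dalle-mini.git`
# or `pip install -e .` from a local clone; setup.cfg reads the version from
# dalle_mini.__version__, so the attribute is expected to exist.
import dalle_mini

print(dalle_mini.__version__)
```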