Merge branch 'main' of https://github.com/borisdayma/dalle-mini into fix-opt_state
Files changed:

- .github/workflows/sync_to_hub_debug.yml +17 -0
- CITATION.cff +44 -0
- README.md +76 -7
- app/app.py +22 -4
- app/dalle_mini +0 -1
- app/gradio/dalle_mini +0 -1
- app/img/loading.gif +0 -0
- dalle_mini/text.py +272 -0
- dev/README.md +122 -0
- dev/encoding/vqgan-jax-encoding-streaming.ipynb +562 -0
- dev/encoding/vqgan-jax-encoding-webdataset.ipynb +461 -0
- dev/inference/dalle_mini +0 -1
- dev/inference/inference_pipeline.ipynb +51 -28
- dev/requirements.txt +3 -5
- requirements.txt +0 -2
- setup.cfg +16 -0
- setup.py +4 -0
.github/workflows/sync_to_hub_debug.yml
ADDED
@@ -0,0 +1,17 @@
```yaml
name: Deploy to debug app

on:
  # to run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  sync-to-hub-debug:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0
      - name: Push to hub
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: git push --force https://boris:$HF_TOKEN@huggingface.co/spaces/flax-community/dalle-mini-debug +HEAD:main
```
CITATION.cff
ADDED
@@ -0,0 +1,44 @@
```yaml
# YAML 1.2
---
abstract: "DALL·E mini is a JAX/Flax reimplementation of OpenAI's DALL·E that requires much smaller hardware resources. By simplifying the architecture and model memory requirements, as well as leveraging open-source code and pre-trained models, we were able to create a model that is 27 times smaller than the original DALL·E and train it on a single TPU v3-8 for only 3 days. DALL·E mini achieves impressive results, albeit of a lower quality than the original system. It can be used for exploration and further experimentation on commodity hardware."
authors:
  -
    family-names: Dayma
    given-names: Boris
  -
    family-names: Patil
    given-names: Suraj
  -
    family-names: Cuenca
    given-names: Pedro
  -
    family-names: Saifullah
    given-names: Khalid
  -
    family-names: Abraham
    given-names: Tanishq
  -
    family-names: "Lê Khắc"
    given-names: "Phúc"
  -
    family-names: Melas
    given-names: Luke
  -
    family-names: Ghosh
    given-names: Ritobrata
cff-version: "1.1.0"
date-released: 2021-07-29
identifiers:
keywords:
  - dalle
  - "text-to-image generation"
  - transformer
  - "zero-shot"
  - JAX
license: "Apache-2.0"
doi: 10.5281/zenodo.5146400
message: "If you use this project, please cite it using these metadata."
repository-code: "https://github.com/borisdayma/dalle-mini"
title: "DALL·E Mini"
version: "v0.1-alpha"
...
```
README.md
CHANGED
````diff
@@ -1,8 +1,8 @@
 ---
 title: DALL·E mini
 emoji: 🥑
-colorFrom:
-colorTo:
+colorFrom: yellow
+colorTo: green
 sdk: streamlit
 app_file: app/app.py
 pinned: false
@@ -16,7 +16,7 @@ _Generate images from a text prompt_
 
 Our logo was generated with DALL·E mini using the prompt "logo of an armchair in the shape of an avocado".
 
-You can create your own pictures with [the demo](https://huggingface.co/spaces/flax-community/dalle-mini)
+You can create your own pictures with [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).
 
 ## How does it work?
 
@@ -26,8 +26,6 @@ Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini
 
 ### Dependencies Installation
 
-The root folder and associated [`requirements.txt`](./requirements.txt) is only for the app.
-
 For development, use [`dev/requirements.txt`](dev/requirements.txt) or [`dev/environment.yaml`](dev/environment.yaml).
 
 ### Training of VQGAN
@@ -52,7 +50,16 @@ To generate sample predictions and understand the inference pipeline step by ste
 
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/dev/inference/inference_pipeline.ipynb)
 
-##
+## FAQ
+
+### Where to find the latest models?
+
+Trained models are on 🤗 Model Hub:
+
+- [VQGAN-f16-16384](https://huggingface.co/flax-community/vqgan_f16_16384) for encoding/decoding images
+- [DALL·E mini](https://huggingface.co/flax-community/dalle-mini) for generating images from a text prompt
+
+### Where does the logo come from?
 
 The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone to us.
 
@@ -70,4 +77,66 @@ The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL
 ## Acknowledgements
 
 - 🤗 Hugging Face for organizing [the FLAX/JAX community week](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects)
-- Google Cloud
+- Google [TPU Research Cloud (TRC) program](https://sites.research.google/trc/) for providing computing resources
+- [Weights & Biases](https://wandb.com/) for providing the infrastructure for experiment tracking and model management
+
+## Citing DALL·E mini
+
+If you find DALL·E mini useful in your research or wish to refer, please use the following BibTeX entry.
+
+```
+@misc{Dayma_DALL·E_Mini_2021,
+  author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
+  doi = {10.5281/zenodo.5146400},
+  month = {7},
+  title = {DALL·E Mini},
+  url = {https://github.com/borisdayma/dalle-mini},
+  year = {2021}
+}
+```
+
+## References
+
+```
+@misc{ramesh2021zeroshot,
+  title={Zero-Shot Text-to-Image Generation},
+  author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
+  year={2021},
+  eprint={2102.12092},
+  archivePrefix={arXiv},
+  primaryClass={cs.CV}
+}
+```
+
+```
+@misc{esser2021taming,
+  title={Taming Transformers for High-Resolution Image Synthesis},
+  author={Patrick Esser and Robin Rombach and Björn Ommer},
+  year={2021},
+  eprint={2012.09841},
+  archivePrefix={arXiv},
+  primaryClass={cs.CV}
+}
+```
+
+```
+@misc{lewis2019bart,
+  title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
+  author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
+  year={2019},
+  eprint={1910.13461},
+  archivePrefix={arXiv},
+  primaryClass={cs.CL}
+}
+```
+
+```
+@misc{radford2021learning,
+  title={Learning Transferable Visual Models From Natural Language Supervision},
+  author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
+  year={2021},
+  eprint={2103.00020},
+  archivePrefix={arXiv},
+  primaryClass={cs.CV}
+}
+```
````
app/app.py
CHANGED
```diff
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # coding: utf-8
 
-import random
 from dalle_mini.backend import ServiceError, get_images_from_backend
 
 import streamlit as st
@@ -55,12 +54,31 @@ st.subheader('Generate images from text')
 
 prompt = st.text_input("What do you want to see?")
 
-
-
+test = st.empty()
 DEBUG = False
 if prompt != "" or (should_run_again and prompt != ""):
     container = st.empty()
-
+    # The following mimics `streamlit.info()`.
+    # I tried to get the secondary background color using `components.streamlit.config.get_options_for_section("theme")["secondaryBackgroundColor"]`
+    # but it returns None.
+    container.markdown(f"""
+    <style> p {{ margin:0 }} div {{ margin:0 }} </style>
+    <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
+        <div class="stAlert">
+            <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
+                <div class="st-b7">
+                    <div class="css-whx05o e13vu3m50">
+                        <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
+                            <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/img/loading.gif" width="30"/>
+                            Generating predictions for: <b>{prompt}</b>
+                        </div>
+                    </div>
+                </div>
+            </div>
+        </div>
+    </div>
+    <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
+    """, unsafe_allow_html=True)
 
     try:
         backend_url = st.secrets["BACKEND_SERVER"]
```
app/dalle_mini
DELETED
```diff
@@ -1 +0,0 @@
-../dalle_mini/
```
app/gradio/dalle_mini
DELETED
```diff
@@ -1 +0,0 @@
-../../dalle_mini/
```
app/img/loading.gif
ADDED
dalle_mini/text.py
ADDED
@@ -0,0 +1,272 @@
```python
"""
Utilities for processing text.
"""

import requests
from pathlib import Path
from unidecode import unidecode

import re, math, random, html
import ftfy

WIKI_STATS_URL = "https://github.com/borisdayma/wikipedia-word-frequency/raw/feat-update/results/enwiki-20210820-words-frequency.txt"
WIKI_STATS_LOCAL = Path(WIKI_STATS_URL).parts[-1]

# based on wiki word occurence
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
temp_token = "xtokx"  # avoid repeating chars


def get_wiki_file():
    if not Path(WIKI_STATS_LOCAL).exists():
        r = requests.get(WIKI_STATS_URL, stream=True)
        with open(WIKI_STATS_LOCAL, "wb") as fd:
            for chunk in r.iter_content(chunk_size=128):
                fd.write(chunk)
    return WIKI_STATS_LOCAL


class HashtagProcessor:
    # Adapted from wordninja library
    # We use our wikipedia word count + a good heuristic to make it work
    def __init__(self):
        self._word_cost = (
            l.split()[0] for l in Path(get_wiki_file()).read_text().splitlines()
        )
        self._word_cost = {
            str(k): math.log(float(i + 1)) for i, k in enumerate(self._word_cost)
        }
        self._max_word = max(len(x) for x in self._word_cost.keys())
        self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")

    def __call__(self, s):
        """Uses dynamic programming to infer the location of spaces in a string without spaces."""
        l = [self._split(x) for x in self._SPLIT_RE.split(s)]
        return " ".join([item for sublist in l for item in sublist])

    def _split(self, s):
        # Find the best match for the i first characters, assuming cost has
        # been built for the i-1 first characters.
        # Returns a pair (match_cost, match_length).
        def best_match(i):
            candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i]))
            return min(
                (c + self._word_cost.get(s[i - k - 1 : i].lower(), 9e999), k + 1)
                for k, c in candidates
            )

        # Build the cost array
        cost = [0]
        for i in range(1, len(s) + 1):
            c, k = best_match(i)
            cost.append(c)

        # Backtrack to recover the minimal-cost string.
        out = []
        i = len(s)
        while i > 0:
            c, k = best_match(i)
            assert c == cost[i]
            newToken = True
            if not s[i - k : i] == "'":  # ignore a lone apostrophe
                if len(out) > 0:
                    # re-attach split 's and split digits
                    if out[-1] == "'s" or (
                        s[i - 1].isdigit() and out[-1][0].isdigit()
                    ):  # digit followed by digit
                        out[-1] = (
                            s[i - k : i] + out[-1]
                        )  # combine current token with previous token
                        newToken = False

            if newToken:
                out.append(s[i - k : i])

            i -= k

        return reversed(out)


def replace_person_token(t):
    "Used for CC12M"
    t = re.sub("<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
    while "<person>" in t:
        t = t.replace(
            "<person>", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1
        )
    return t


def fix_html(t):
    "Adapted from fastai"
    t = (
        t.replace("#39;", "'")
        .replace("&amp;", "&")
        .replace("amp;", "&")
        .replace("#146;", "'")
        .replace("nbsp;", " ")
        .replace("#36;", "$")
        .replace("\\n", "\n")
        .replace("quot;", "'")
        .replace("<br />", "\n")
        .replace('\\"', '"')
        .replace("<unk>", " ")
        .replace(" @.@ ", ".")
        .replace(" @-@ ", "-")
    )
    return html.unescape(t)


def replace_punctuation_with_commas(t):
    return re.sub("""([()[\].,|:;?!=+~\-])""", ",", t)


def simplify_quotes(t):
    return re.sub("""['"`]""", ' " ', t)


def merge_quotes(t):
    return re.sub('(\s*"+\s*)+', ' " ', t)


def remove_comma_numbers(t):
    def _f(t):
        return re.sub("(\d),(\d{3})", r"\1\2", t)

    return _f(_f(t))


def pre_process_dot_numbers(t):
    return re.sub("(\d)\.(\d)", fr"\1{temp_token}dot{temp_token}\2", t)


def post_process_dot_numbers(t):
    return re.sub(f"{temp_token}dot{temp_token}", ".", t)


def pre_process_quotes(t):
    # allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've
    return re.sub(
        r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)", fr"{temp_token}quote{temp_token}", t
    )


def post_process_quotes(t):
    return re.sub(f"{temp_token}quote{temp_token}", "'", t)


def merge_commas(t):
    return re.sub("(\s*,+\s*)+", ", ", t)


def add_space_after_commas(t):
    return re.sub(",", ", ", t)


def handle_special_chars(t):
    "Handle special characters"
    # replace "-" with a space when between words without space
    t = re.sub("([a-zA-Z])-([a-zA-Z])", r"\1 \2", t)
    # always add space around &
    return re.sub("&", " & ", t)


def expand_hashtags(t, hashtag_processor):
    "Remove # and try to split words"
    return re.sub("#(\w+)", lambda m: hashtag_processor(m.group(1)), t)


_re_ignore_chars = """[_#\/\\%]"""


def ignore_chars(t):
    "Ignore useless characters"
    return re.sub(_re_ignore_chars, " ", t)


def remove_extra_spaces(t):
    "Remove extra spaces (including \t and \n)"
    return re.sub("\s+", " ", t)


def remove_repeating_chars(t):
    "If the same character is present 4+ times (not 3 because of roman 'VIII'), replace with single instance"
    return re.sub(r"(\D)(\1{3,})", r"\1", t)


def remove_urls(t):
    return re.sub(r"http\S+", "", t)


def remove_html_tags(t):
    return re.sub("<[^<]+?>", "", t)


def remove_first_last_commas(t):
    t = t.strip()
    t = t[:-1] if t and t[-1] == "," else t
    t = t[1:] if t and t[0] == "," else t
    return t.strip()


def remove_wiki_ref(t):
    t = re.sub(r"\A\s*\[\d+\]", "", t)
    return re.sub(r"\[\d+\]\s*\Z", "", t)


class TextNormalizer:
    "Normalize text"

    def __init__(self):
        self._hashtag_processor = HashtagProcessor()

    def __call__(self, t, clip=False):
        # fix some characters
        t = ftfy.fix_text(t)
        # fix html
        t = fix_html(t)
        if not clip:
            # decode and simplify text: see unidecode library
            t = unidecode(t)
            # lower case
            t = t.lower()
            # replace <PERSON> (for CC12M)
            t = replace_person_token(t)
        # remove wiki reference (for WIT)
        t = remove_wiki_ref(t)
        # remove html tags
        t = remove_html_tags(t)
        # remove urls
        t = remove_urls(t)
        # remove commas in numbers
        t = remove_comma_numbers(t)
        if not clip:
            # handle dots in numbers and quotes - Part 1
            t = pre_process_dot_numbers(t)
            t = pre_process_quotes(t)
            # handle special characters
            t = handle_special_chars(t)
            # handle hashtags
            t = expand_hashtags(t, self._hashtag_processor)
            # ignore useless characters
            t = ignore_chars(t)
            # simplify quotes
            t = simplify_quotes(t)
            # all punctuation becomes commas
            t = replace_punctuation_with_commas(t)
            # handle dots in numbers and quotes - Part 2
            t = post_process_dot_numbers(t)
            t = post_process_quotes(t)
            # handle repeating characters
            t = remove_repeating_chars(t)
            # merge commas
            t = merge_commas(t)
            # merge quotes
            t = merge_quotes(t)
        # remove multiple spaces
        t = remove_extra_spaces(t)
        # remove first and last comma
        t = remove_first_last_commas(t)
        # always start with a space
        return f" {t}" if not clip else t
```
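The new `dalle_mini/text.py` module is meant to be imported as a library. A minimal usage sketch follows; the caption string is a made-up example, instantiating `TextNormalizer` downloads the Wikipedia word-frequency file referenced by `WIKI_STATS_URL` on first use, and the `clip=True` path skips the unidecode/lowercasing and punctuation-rewriting steps:

```python
# Illustrative usage sketch of dalle_mini.text (example caption is hypothetical).
from dalle_mini.text import TextNormalizer

normalizer = TextNormalizer()  # downloads the word-frequency file on first use

caption = "Check out our #SummerVacation pics!!!! 1,024 photos at http://example.com <br />"
print(normalizer(caption))              # full normalization: lowercase, hashtag splitting, punctuation -> commas
print(normalizer(caption, clip=True))   # lighter path: keeps case, skips punctuation rewriting
```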
dev/README.md
ADDED
@@ -0,0 +1,122 @@
````markdown
# Development Instructions for TPU

## Setup

- Apply to the [TRC program](https://sites.research.google/trc/) for free TPU credits if you're eligible.
- Follow the [Cloud TPU VM User's Guide](https://cloud.google.com/tpu/docs/users-guide-tpu-vm) to set up gcloud.
- Verify `gcloud config list`, in particular account, project & zone.
- Create a TPU VM per the guide and connect to it.

When needing a larger disk:

- Create a balanced persistent disk (SSD, so pricier than the default HDD but much faster): `gcloud compute disks create DISK_NAME --size SIZE_IN_GB --type pd-balanced`
- Attach the disk to your instance by adding `--data-disk source=REF` per the ["Adding a persistent disk to a TPU VM" guide](https://cloud.google.com/tpu/docs/setup-persistent-disk), eg `gcloud alpha compute tpus tpu-vm create INSTANCE_NAME --accelerator-type=v3-8 --version=v2-alpha --data-disk source=projects/tpu-toys/zones/europe-west4-a/disks/DISK_NAME`
- Format the partition as described in the guide.
- Make sure to set up automatic remount of the disk at restart.

## Connect VS Code

- Find the external IP in the UI or with `gcloud alpha compute tpus tpu-vm describe INSTANCE_NAME`
- Verify you can connect in a terminal with `ssh EXTERNAL_IP -i ~/.ssh/google_compute_engine`
- Add the same command as an ssh host in VS Code.
- Check the config file:

  ```
  Host INSTANCE_NAME
    HostName EXTERNAL_IP
    IdentityFile ~/.ssh/google_compute_engine
  ```

## Environment configuration

### Use virtual environments (optional)

We recommend using virtual environments (such as conda, venv or pyenv-virtualenv).

If you want to use `pyenv` and `pyenv-virtualenv`:

- Installation

  - [Set up build environment](https://github.com/pyenv/pyenv/wiki#suggested-build-environment)
  - Use [pyenv-installer](https://github.com/pyenv/pyenv-installer): `curl https://pyenv.run | bash`
  - bash set-up:

    ```bash
    echo '\n'\
      '# pyenv setup \n'\
      'export PYENV_ROOT="$HOME/.pyenv" \n'\
      'export PATH="$PYENV_ROOT/bin:$PATH" \n'\
      'eval "$(pyenv init --path)" \n'\
      'eval "$(pyenv init -)" \n'\
      'eval "$(pyenv virtualenv-init -)"' >> ~/.bashrc
    ```

- Usage

  - Install a python version: `pyenv install X.X.X`
  - Create a virtual environment: `pyenv virtualenv 3.9.6 dalle_env`
  - Activate: `pyenv activate dalle_env`

  Note: you can auto-activate your environment at a location with `echo dalle_env >> .python-version`

### Tools

- Git

  - `git config --global user.email "[email protected]"`
  - `git config --global user.name "First Last"`

- GitHub CLI

  - See the [installation instructions](https://github.com/cli/cli/blob/trunk/docs/install_linux.md)
  - `gh auth login`

- Direnv

  - Install direnv: `sudo apt-get update && sudo apt-get install direnv`
  - bash set-up:

    ```bash
    echo -e '\n'\
      '# direnv setup \n'\
      'eval "$(direnv hook bash)" \n' >> ~/.bashrc
    ```

### Set up repo

- Clone the repo: `gh repo clone borisdayma/dalle-mini`
- If using `pyenv-virtualenv`, auto-activate the env: `echo dalle_env >> .python-version`

## Environment

- Install the following (use it later to update our dev requirements.txt)

  ```
  requests
  pillow
  jupyterlab
  ipywidgets

  -e ../datasets[streaming]
  -e ../transformers
  -e ../webdataset

  # JAX
  --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
  jax[tpu]>=0.2.16
  flax
  ```

- `transformers-cli login`

---

- Set `HF_HOME="/mnt/disks/persist/cache/huggingface"` in `/etc/environment` and ensure you have the required permissions, then restart.

## Working with datasets or models

- Install [Git LFS](https://github.com/git-lfs/git-lfs/wiki/Installation)
- Clone a dataset without large files: `GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/datasets/.../...`
- Use a local [credential store](https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage) for caching credentials
- Track specific extensions: `git lfs track "*.ext"`
- See files tracked with LFS with `git lfs ls-files`
````
dev/encoding/vqgan-jax-encoding-streaming.ipynb
ADDED
@@ -0,0 +1,562 @@
# VQGAN JAX Encoding for 🤗 Datasets in streaming mode

This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and 🤗 Datasets in streaming mode.

This example uses our YFCC100M dataset, but it should be easy to adapt to any other image/caption dataset in the huggingface hub.

```python
import io

import requests
from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
import os

import jax
from jax import pmap
```

## Dataset and Parameters

```python
import datasets
from datasets import Dataset, load_dataset
```

We'll use the `validation` set for testing. Adjust accordingly.

```python
dataset = load_dataset('dalle-mini/YFCC100M_OpenAI_subset', use_auth_token=True, streaming=True, split='validation')
```

```python
from pathlib import Path

yfcc100m = Path.home()/'data'/'YFCC100M_OpenAI_subset'
yfcc100m_output = yfcc100m/'encoded'  # Output directory for encoded files
```

```python
batch_size = 128   # Per device
num_workers = 16   # Unused in streaming mode
```

### Data preparation

* Images: we transform them so they are center-cropped and square, all of the same size so we can build batches for TPU/GPU processing.
* Captions: we extract a single `caption` column from the source data, by concatenating the cleaned title and description.

These transformations are done using the Datasets `map` function. In the case of streaming datasets, transformations will run as needed instead of pre-processing the dataset at once.

This helper function is used to decode images from the bytes retrieved in `streaming` mode.

```python
from PIL import Image
import io

def get_image(byte_stream):
    image = Image.open(io.BytesIO(byte_stream))
    return image.convert('RGB')
```

Image processing:

```python
def center_crop(image, max_size=256):
    # Note: we allow upscaling too. We should exclude small images.
    image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)
    image = TF.center_crop(image, output_size=2 * [max_size])
    return image

preprocess_image = T.Compose([
    get_image,
    center_crop,
    T.ToTensor(),
    lambda t: t.permute(1, 2, 0)   # Reorder, we need dimensions last
])
```

Caption preparation:

```python
import string

def create_caption(title, description):
    title = title.strip()
    description = description.strip()
    if len(title) > 0 and title[-1] not in '.!?': title += '.'
    return f'{title} {description}'
```

And this is the basic transformation function to use in `map`. We don't really need the `key`, but we'll keep it for reference. Since we are returning a new dictionary (as opposed to adding entries to the input), this also removes any metadata columns we don't need.

```python
def prepare_item(item):
    return {
        'key': item['key'],
        'caption': create_caption(item['title_clean'], item['description_clean']),
        'image': preprocess_image(item['img'])
    }
```

Unlike when using non-streaming datasets, the following operation completes immediately in streaming mode. In streaming mode, `num_proc` is not supported.

```python
prepared_dataset = dataset.map(prepare_item, batched=False)
```

```python
%%time
item = next(iter(prepared_dataset))
```

```python
assert(list(item.keys()) == ['key', 'caption', 'image'])
```

```python
item['image'].shape
```

```python
T.ToPILImage()(item['image'].permute(2, 0, 1))
```

### Torch DataLoader

We'll create a PyTorch DataLoader for convenience. This allows us to easily take batches of our desired size.

We won't be using parallel processing of the DataLoader for now, as the items will be retrieved on the fly. We could attempt to do it using these recommendations: https://pytorch.org/docs/stable/data.html#multi-process-data-loading. For performance considerations, please refer to this thread: https://discuss.huggingface.co/t/allow-streaming-of-large-datasets-with-image-audio/8062/13

```python
import torch
from torch.utils.data import DataLoader
```

```python
torch_dataset = prepared_dataset.with_format("torch")
```

**Note**: according to my tests, `num_workers` is not compatible with Datasets in streaming mode. Processes deadlock and there's no progress.

```python
dataloader = DataLoader(torch_dataset, batch_size=batch_size * jax.device_count())
```

```python
batch = next(iter(dataloader))
```

```python
batch['image'].shape
```

## VQGAN-JAX model

```python
from vqgan_jax.modeling_flax_vqgan import VQModel
```

We'll use a VQGAN trained with Taming Transformers and converted to a JAX model.

```python
model = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
```

## Encoding

Encoding is really simple using `shard` to automatically distribute "superbatches" across devices, and `pmap`. This is all it takes to create our encoding function, which will be jitted on first use.

```python
from flax.training.common_utils import shard
from functools import partial
```

```python
@partial(jax.pmap, axis_name="batch")
def encode(batch):
    # Not sure if we should `replicate` params, does not seem to have any effect
    _, indices = model.encode(batch)
    return indices
```

### Encoding loop

```python
import os
import pandas as pd

def encode_captioned_dataset(dataloader, output_dir, save_every=14):
    output_dir.mkdir(parents=True, exist_ok=True)

    # Saving strategy:
    # - Create a new file every so often to prevent excessive file seeking.
    # - Save each batch after processing.
    # - Keep the file open until we are done with it.
    file = None
    for n, batch in enumerate(tqdm(iter(dataloader))):
        if (n % save_every == 0):
            if file is not None:
                file.close()
            split_num = n // save_every
            file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')

        images = batch["image"].numpy()
        images = shard(images.squeeze())
        encoded = encode(images)
        encoded = encoded.reshape(-1, encoded.shape[-1])

        keys = batch["key"]
        captions = batch["caption"]

        encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))
        batch_df = pd.DataFrame.from_dict({"key": keys, "caption": captions, "encoding": encoded_as_string})
        batch_df.to_json(file, orient='records', lines=True)
```

Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024.

```python
save_every = 318
```

```python
encode_captioned_dataset(dataloader, yfcc100m_output, save_every=save_every)
```

Output: `28it [01:17, 1.60s/it]`

This is ~10-15 slower than local encoding from an SSD. For performance considerations, see the discussion at https://discuss.huggingface.co/t/allow-streaming-of-large-datasets-with-image-audio/8062/13.

----
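The splits written by `encode_captioned_dataset` above are JSON-lines files with `key`, `caption` and `encoding` columns. A small illustrative sketch of reading one back follows; the exact path and file name are assumptions based on `yfcc100m_output` and the `split_{split_num:05x}.jsonl` naming pattern:

```python
# Illustrative sketch: read back one encoded split produced by the notebook above.
# The path and file name are assumptions based on the naming pattern used there.
from pathlib import Path
import pandas as pd

split = Path.home() / "data" / "YFCC100M_OpenAI_subset" / "encoded" / "split_00000.jsonl"
df = pd.read_json(split, orient="records", lines=True)
print(df.columns.tolist())           # expected: ['key', 'caption', 'encoding']
print(df.iloc[0]["encoding"][:80])   # encoding stored as a comma-separated string of token ids
```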
dev/encoding/vqgan-jax-encoding-webdataset.ipynb
ADDED
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "d0b72877",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# VQGAN JAX Encoding for `webdataset`"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "markdown",
|
13 |
+
"id": "ba7b31e6",
|
14 |
+
"metadata": {},
|
15 |
+
"source": [
|
16 |
+
"This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
|
17 |
+
"\n",
|
18 |
+
"This example uses a small subset of YFCC100M we created for testing, but it should be easy to adapt to any other image/caption dataset in the `webdataset` format."
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": null,
|
24 |
+
"id": "3b59489e",
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [],
|
27 |
+
"source": [
|
28 |
+
"import numpy as np\n",
|
29 |
+
"from tqdm import tqdm\n",
|
30 |
+
"\n",
|
31 |
+
"import torch\n",
|
32 |
+
"import torchvision.transforms as T\n",
|
33 |
+
"import torchvision.transforms.functional as TF\n",
|
34 |
+
"from torchvision.transforms import InterpolationMode\n",
|
35 |
+
"import math\n",
|
36 |
+
"\n",
|
37 |
+
"import webdataset as wds\n",
|
38 |
+
"\n",
|
39 |
+
"import jax\n",
|
40 |
+
"from jax import pmap"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"id": "c7c4c1e6",
|
46 |
+
"metadata": {},
|
47 |
+
"source": [
|
48 |
+
"## Dataset and Parameters"
|
49 |
+
]
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"cell_type": "markdown",
|
53 |
+
"id": "9822850f",
|
54 |
+
"metadata": {},
|
55 |
+
"source": [
|
56 |
+
"The following is the list of shards we'll process. We hardcode the length of data so that we can see nice progress bars using `tqdm`."
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": null,
|
62 |
+
"id": "1265dbfe",
|
63 |
+
"metadata": {},
|
64 |
+
"outputs": [],
|
65 |
+
"source": [
|
66 |
+
"shards = 'https://huggingface.co/datasets/dalle-mini/YFCC100M_OpenAI_subset/resolve/main/data/shard-{0000..0008}.tar'\n",
|
67 |
+
"length = 8320"
|
68 |
+
]
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"cell_type": "markdown",
|
72 |
+
"id": "7e38fa14",
|
73 |
+
"metadata": {},
|
74 |
+
"source": [
|
75 |
+
"If we are extra cautious or our server is unreliable, we can enable retries by providing a custom `curl` retrieval command:"
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"execution_count": null,
|
81 |
+
"id": "4c8c5960",
|
82 |
+
"metadata": {},
|
83 |
+
"outputs": [],
|
84 |
+
"source": [
|
85 |
+
"# Enable curl retries to try to work around temporary network / server errors.\n",
|
86 |
+
"# This shouldn't be necessary when using reliable servers.\n",
|
87 |
+
"# shards = f'pipe:curl -s --retry 5 --retry-delay 5 -L {shards} || true'"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": null,
|
93 |
+
"id": "13c6631b",
|
94 |
+
"metadata": {},
|
95 |
+
"outputs": [],
|
96 |
+
"source": [
|
97 |
+
"from pathlib import Path\n",
|
98 |
+
"\n",
|
99 |
+
"# Output directory for encoded files\n",
|
100 |
+
"encoded_output = Path.home()/'data'/'wds'/'encoded'\n",
|
101 |
+
"\n",
|
102 |
+
"batch_size = 128 # Per device\n",
|
103 |
+
"num_workers = 8 # For parallel processing"
|
104 |
+
]
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"cell_type": "code",
|
108 |
+
"execution_count": null,
|
109 |
+
"id": "3435fb85",
|
110 |
+
"metadata": {},
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"bs = batch_size * jax.device_count() # You can use a smaller size while testing\n",
|
114 |
+
"batches = math.ceil(length / bs)"
|
115 |
+
]
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"cell_type": "markdown",
|
119 |
+
"id": "88598e4b",
|
120 |
+
"metadata": {},
|
121 |
+
"source": [
|
122 |
+
"Image processing"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"cell_type": "code",
|
127 |
+
"execution_count": null,
|
128 |
+
"id": "669b35df",
|
129 |
+
"metadata": {},
|
130 |
+
"outputs": [],
|
131 |
+
"source": [
|
132 |
+
"def center_crop(image, max_size=256):\n",
|
133 |
+
" # Note: we allow upscaling too. We should exclude small images. \n",
|
134 |
+
" image = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
|
135 |
+
" image = TF.center_crop(image, output_size=2 * [max_size])\n",
|
136 |
+
" return image\n",
|
137 |
+
"\n",
|
138 |
+
"preprocess_image = T.Compose([\n",
|
139 |
+
" center_crop,\n",
|
140 |
+
" T.ToTensor(),\n",
|
141 |
+
" lambda t: t.permute(1, 2, 0) # Reorder, we need dimensions last\n",
|
142 |
+
"])"
|
143 |
+
]
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"cell_type": "markdown",
|
147 |
+
"id": "a185e90c",
|
148 |
+
"metadata": {},
|
149 |
+
"source": [
|
150 |
+
"Caption preparation.\n",
|
151 |
+
"\n",
|
152 |
+
"Note that we receive the contents of the `json` structure, which will be replaced by the string we return.\n",
|
153 |
+
"If we want to keep other fields inside `json`, we can add `caption` as a new field."
|
154 |
+
]
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"cell_type": "code",
|
158 |
+
"execution_count": null,
|
159 |
+
"id": "423ee10e",
|
160 |
+
"metadata": {},
|
161 |
+
"outputs": [],
|
162 |
+
"source": [
|
163 |
+
"def create_caption(item):\n",
|
164 |
+
" title = item['title_clean'].strip()\n",
|
165 |
+
" description = item['description_clean'].strip()\n",
|
166 |
+
" if len(title) > 0 and title[-1] not in '.!?': title += '.'\n",
|
167 |
+
" return f'{title} {description}'"
|
168 |
+
]
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"cell_type": "markdown",
|
172 |
+
"id": "8d3a95db",
|
173 |
+
"metadata": {},
|
174 |
+
"source": [
|
175 |
+
"When an error occurs (a download is disconnected, an image cannot be decoded, etc) the process stops with an exception. We can use one of the exception handlers provided by the `webdataset` library, such as `wds.warn_and_continue` or `wds.ignore_and_continue` to ignore the offending entry and keep iterating.\n",
|
176 |
+
"\n",
|
177 |
+
"**IMPORTANT WARNING:** Do not use error handlers to ignore exceptions until you have tested that your processing pipeline works fine. Otherwise, the process will continue trying to find a valid entry, and it will consume your whole dataset without doing any work.\n",
|
178 |
+
"\n",
|
179 |
+
"We can also create our custom exception handler as demonstrated here:"
|
180 |
+
]
|
181 |
+
},
|
182 |
+
{
|
183 |
+
"cell_type": "code",
|
184 |
+
"execution_count": null,
|
185 |
+
"id": "369d9719",
|
186 |
+
"metadata": {},
|
187 |
+
"outputs": [],
|
188 |
+
"source": [
|
189 |
+
"# UNUSED - Log exceptions to a file\n",
|
190 |
+
"def ignore_and_log(exn):\n",
|
191 |
+
" with open('errors.txt', 'a') as f:\n",
|
192 |
+
" f.write(f'{repr(exn)}\\n')\n",
|
193 |
+
" return True"
|
194 |
+
]
|
195 |
+
},
|
196 |
+
{
|
197 |
+
"cell_type": "code",
|
198 |
+
"execution_count": null,
|
199 |
+
"id": "27de1414",
|
200 |
+
"metadata": {},
|
201 |
+
"outputs": [],
|
202 |
+
"source": [
|
203 |
+
"# Or simply use `wds.ignore_and_continue`\n",
|
204 |
+
"exception_handler = wds.warn_and_continue"
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"cell_type": "code",
|
209 |
+
"execution_count": null,
|
210 |
+
"id": "5149b6d5",
|
211 |
+
"metadata": {},
|
212 |
+
"outputs": [],
|
213 |
+
"source": [
|
214 |
+
"dataset = wds.WebDataset(shards,\n",
|
215 |
+
" length=batches, # Hint so `len` is implemented\n",
|
216 |
+
" shardshuffle=False, # Keep same order for encoded files for easier bookkeeping. Set to `True` for training.\n",
|
217 |
+
" handler=exception_handler, # Ignore read errors instead of failing.\n",
|
218 |
+
")\n",
|
219 |
+
"\n",
|
220 |
+
"dataset = (dataset \n",
|
221 |
+
" .decode('pil') # decode image with PIL\n",
|
222 |
+
"# .map_dict(jpg=preprocess_image, json=create_caption, handler=exception_handler) # Process fields with functions defined above\n",
|
223 |
+
" .map_dict(jpg=preprocess_image, json=create_caption) # Process fields with functions defined above\n",
|
224 |
+
" .to_tuple('__key__', 'jpg', 'json') # filter to keep only key (for reference), image, caption.\n",
|
225 |
+
" .batched(bs)) # better to batch in the dataset (but we could also do it in the dataloader) - this arg does not affect speed and we could remove it"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "code",
|
230 |
+
"execution_count": null,
|
231 |
+
"id": "8cac98cb",
|
232 |
+
"metadata": {
|
233 |
+
"scrolled": true
|
234 |
+
},
|
235 |
+
"outputs": [],
|
236 |
+
"source": [
|
237 |
+
"%%time\n",
|
238 |
+
"keys, images, captions = next(iter(dataset))"
|
239 |
+
]
|
240 |
+
},
|
241 |
+
{
|
242 |
+
"cell_type": "code",
|
243 |
+
"execution_count": null,
|
244 |
+
"id": "cd268fbf",
|
245 |
+
"metadata": {},
|
246 |
+
"outputs": [],
|
247 |
+
"source": [
|
248 |
+
"images.shape"
|
249 |
+
]
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"cell_type": "code",
|
253 |
+
"execution_count": null,
|
254 |
+
"id": "c24693c0",
|
255 |
+
"metadata": {},
|
256 |
+
"outputs": [],
|
257 |
+
"source": [
|
258 |
+
"T.ToPILImage()(images[0].permute(2, 0, 1))"
|
259 |
+
]
|
260 |
+
},
|
261 |
+
{
|
262 |
+
"cell_type": "markdown",
|
263 |
+
"id": "44d50a51",
|
264 |
+
"metadata": {},
|
265 |
+
"source": [
|
266 |
+
"### Torch DataLoader"
|
267 |
+
]
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"cell_type": "code",
|
271 |
+
"execution_count": null,
|
272 |
+
"id": "e2df5e13",
|
273 |
+
"metadata": {},
|
274 |
+
"outputs": [],
|
275 |
+
"source": [
|
276 |
+
"dl = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=num_workers)"
|
277 |
+
]
|
278 |
+
},
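Because batching already happens inside the dataset, `batch_size=None` tells PyTorch not to batch again: each element yielded by the loader is a full `(keys, images, captions)` batch. A minimal sketch of consuming it, using the names defined above:

for keys, images, captions in dl:
    print(len(keys), images.shape)   # e.g. 1024 and torch.Size([1024, 256, 256, 3]) with the batch size above
    break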
|
279 |
+
{
|
280 |
+
"cell_type": "markdown",
|
281 |
+
"id": "a354472b",
|
282 |
+
"metadata": {},
|
283 |
+
"source": [
|
284 |
+
"## VQGAN-JAX model"
|
285 |
+
]
|
286 |
+
},
|
287 |
+
{
|
288 |
+
"cell_type": "code",
|
289 |
+
"execution_count": null,
|
290 |
+
"id": "2fcf01d7",
|
291 |
+
"metadata": {},
|
292 |
+
"outputs": [],
|
293 |
+
"source": [
|
294 |
+
"from vqgan_jax.modeling_flax_vqgan import VQModel"
|
295 |
+
]
|
296 |
+
},
|
297 |
+
{
|
298 |
+
"cell_type": "markdown",
|
299 |
+
"id": "9daa636d",
|
300 |
+
"metadata": {},
|
301 |
+
"source": [
|
302 |
+
"We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
|
303 |
+
]
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"cell_type": "code",
|
307 |
+
"execution_count": null,
|
308 |
+
"id": "47a8b818",
|
309 |
+
"metadata": {
|
310 |
+
"scrolled": true
|
311 |
+
},
|
312 |
+
"outputs": [],
|
313 |
+
"source": [
|
314 |
+
"model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
|
315 |
+
]
|
316 |
+
},
|
317 |
+
{
|
318 |
+
"cell_type": "markdown",
|
319 |
+
"id": "62ad01c3",
|
320 |
+
"metadata": {},
|
321 |
+
"source": [
|
322 |
+
"## Encoding"
|
323 |
+
]
|
324 |
+
},
|
325 |
+
{
|
326 |
+
"cell_type": "markdown",
|
327 |
+
"id": "20357f74",
|
328 |
+
"metadata": {},
|
329 |
+
"source": [
|
330 |
+
"Encoding is really simple using `shard` to automatically distribute \"superbatches\" across devices, and `pmap`. This is all it takes to create our encoding function, that will be jitted on first use."
|
331 |
+
]
|
332 |
+
},
|
333 |
+
{
|
334 |
+
"cell_type": "code",
|
335 |
+
"execution_count": null,
|
336 |
+
"id": "6686b004",
|
337 |
+
"metadata": {},
|
338 |
+
"outputs": [],
|
339 |
+
"source": [
|
340 |
+
"from flax.training.common_utils import shard\n",
|
341 |
+
"from functools import partial"
|
342 |
+
]
|
343 |
+
},
|
344 |
+
{
|
345 |
+
"cell_type": "code",
|
346 |
+
"execution_count": null,
|
347 |
+
"id": "322a4619",
|
348 |
+
"metadata": {},
|
349 |
+
"outputs": [],
|
350 |
+
"source": [
|
351 |
+
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
352 |
+
"def encode(batch):\n",
|
353 |
+
" # Not sure if we should `replicate` params, does not seem to have any effect\n",
|
354 |
+
" _, indices = model.encode(batch)\n",
|
355 |
+
" return indices"
|
356 |
+
]
|
357 |
+
},
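A quick shape sketch of what `shard` plus the pmapped `encode` do (assuming 8 devices, a total batch size of 1024, and the f16 VQGAN, which yields a 16x16 = 256 code grid per 256x256 image; the encoding loop below performs exactly these steps):

import numpy as np

dummy = np.zeros((1024, 256, 256, 3), dtype=np.float32)  # stand-in for one batch of images
sharded = shard(dummy)                                    # (8, 128, 256, 256, 3): one leading axis per device
# indices = encode(sharded)                               # (8, 128, 256): VQGAN code indices per image
# indices.reshape(-1, indices.shape[-1])                  # (1024, 256): flattened back to the full batch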
|
358 |
+
{
|
359 |
+
"cell_type": "markdown",
|
360 |
+
"id": "14375a41",
|
361 |
+
"metadata": {},
|
362 |
+
"source": [
|
363 |
+
"### Encoding loop"
|
364 |
+
]
|
365 |
+
},
|
366 |
+
{
|
367 |
+
"cell_type": "code",
|
368 |
+
"execution_count": null,
|
369 |
+
"id": "ff6c10d4",
|
370 |
+
"metadata": {},
|
371 |
+
"outputs": [],
|
372 |
+
"source": [
|
373 |
+
"import os\n",
|
374 |
+
"import pandas as pd\n",
|
375 |
+
"\n",
|
376 |
+
"def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n",
|
377 |
+
" output_dir.mkdir(parents=True, exist_ok=True)\n",
|
378 |
+
"\n",
|
379 |
+
" # Saving strategy:\n",
|
380 |
+
" # - Create a new file every so often to prevent excessive file seeking.\n",
|
381 |
+
" # - Save each batch after processing.\n",
|
382 |
+
" # - Keep the file open until we are done with it.\n",
|
383 |
+
" file = None \n",
|
384 |
+
" for n, (keys, images, captions) in enumerate(tqdm(dataloader)):\n",
|
385 |
+
" if (n % save_every == 0):\n",
|
386 |
+
" if file is not None:\n",
|
387 |
+
" file.close()\n",
|
388 |
+
" split_num = n // save_every\n",
|
389 |
+
" file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')\n",
|
390 |
+
"\n",
|
391 |
+
" images = shard(images.numpy().squeeze())\n",
|
392 |
+
" encoded = encode(images)\n",
|
393 |
+
" encoded = encoded.reshape(-1, encoded.shape[-1])\n",
|
394 |
+
"\n",
|
395 |
+
" encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
|
396 |
+
" batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded_as_string})\n",
|
397 |
+
" batch_df.to_json(file, orient='records', lines=True)"
|
398 |
+
]
|
399 |
+
},
|
400 |
+
{
|
401 |
+
"cell_type": "markdown",
|
402 |
+
"id": "09ff75a3",
|
403 |
+
"metadata": {},
|
404 |
+
"source": [
|
405 |
+
"Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024."
|
406 |
+
]
|
407 |
+
},
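Back-of-the-envelope check (approximate; assumes ~1.5 KB per JSON line for the 256 comma-separated codes plus key and caption):

splits_every = 318          # batches per output file
total_bs = 1024             # images per batch
bytes_per_line = 1536       # assumed average size of one JSONL record
print(splits_every * total_bs * bytes_per_line / 1e6)   # ~500 (MB)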
|
408 |
+
{
|
409 |
+
"cell_type": "code",
|
410 |
+
"execution_count": null,
|
411 |
+
"id": "96222bb4",
|
412 |
+
"metadata": {},
|
413 |
+
"outputs": [],
|
414 |
+
"source": [
|
415 |
+
"save_every = 318"
|
416 |
+
]
|
417 |
+
},
|
418 |
+
{
|
419 |
+
"cell_type": "code",
|
420 |
+
"execution_count": null,
|
421 |
+
"id": "7704863d",
|
422 |
+
"metadata": {},
|
423 |
+
"outputs": [],
|
424 |
+
"source": [
|
425 |
+
"encode_captioned_dataset(dl, encoded_output, save_every=save_every)"
|
426 |
+
]
|
427 |
+
},
|
428 |
+
{
|
429 |
+
"cell_type": "markdown",
|
430 |
+
"id": "8953dd84",
|
431 |
+
"metadata": {},
|
432 |
+
"source": [
|
433 |
+
"----"
|
434 |
+
]
|
435 |
+
}
|
436 |
+
],
|
437 |
+
"metadata": {
|
438 |
+
"interpreter": {
|
439 |
+
"hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
|
440 |
+
},
|
441 |
+
"kernelspec": {
|
442 |
+
"display_name": "Python 3 (ipykernel)",
|
443 |
+
"language": "python",
|
444 |
+
"name": "python3"
|
445 |
+
},
|
446 |
+
"language_info": {
|
447 |
+
"codemirror_mode": {
|
448 |
+
"name": "ipython",
|
449 |
+
"version": 3
|
450 |
+
},
|
451 |
+
"file_extension": ".py",
|
452 |
+
"mimetype": "text/x-python",
|
453 |
+
"name": "python",
|
454 |
+
"nbconvert_exporter": "python",
|
455 |
+
"pygments_lexer": "ipython3",
|
456 |
+
"version": "3.8.10"
|
457 |
+
}
|
458 |
+
},
|
459 |
+
"nbformat": 4,
|
460 |
+
"nbformat_minor": 5
|
461 |
+
}
|
dev/inference/dalle_mini
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
../../dalle_mini
|
|
|
|
dev/inference/inference_pipeline.ipynb
CHANGED
@@ -6,7 +6,7 @@
|
|
6 |
"name": "DALL·E mini - Inference pipeline.ipynb",
|
7 |
"provenance": [],
|
8 |
"collapsed_sections": [],
|
9 |
-
"authorship_tag": "
|
10 |
"include_colab_link": true
|
11 |
},
|
12 |
"kernelspec": {
|
@@ -22,6 +22,7 @@
|
|
22 |
"49304912717a4995ae45d04a59d1f50e": {
|
23 |
"model_module": "@jupyter-widgets/controls",
|
24 |
"model_name": "HBoxModel",
|
|
|
25 |
"state": {
|
26 |
"_view_name": "HBoxView",
|
27 |
"_dom_classes": [],
|
@@ -42,6 +43,7 @@
|
|
42 |
"5fd9f97986024e8db560a6737ade9e2e": {
|
43 |
"model_module": "@jupyter-widgets/base",
|
44 |
"model_name": "LayoutModel",
|
|
|
45 |
"state": {
|
46 |
"_view_name": "LayoutView",
|
47 |
"grid_template_rows": null,
|
@@ -93,6 +95,7 @@
|
|
93 |
"caced43e3a4c493b98fb07cb41db045c": {
|
94 |
"model_module": "@jupyter-widgets/controls",
|
95 |
"model_name": "FloatProgressModel",
|
|
|
96 |
"state": {
|
97 |
"_view_name": "ProgressView",
|
98 |
"style": "IPY_MODEL_40c54b9454d346aabd197f2bcf189467",
|
@@ -116,6 +119,7 @@
|
|
116 |
"0acc161f2e9948b68b3fc4e57ef333c9": {
|
117 |
"model_module": "@jupyter-widgets/controls",
|
118 |
"model_name": "HTMLModel",
|
|
|
119 |
"state": {
|
120 |
"_view_name": "HTMLView",
|
121 |
"style": "IPY_MODEL_7e7c488f57fc4acb8d261e2db81d61f0",
|
@@ -136,6 +140,7 @@
|
|
136 |
"40c54b9454d346aabd197f2bcf189467": {
|
137 |
"model_module": "@jupyter-widgets/controls",
|
138 |
"model_name": "ProgressStyleModel",
|
|
|
139 |
"state": {
|
140 |
"_view_name": "StyleView",
|
141 |
"_model_name": "ProgressStyleModel",
|
@@ -151,6 +156,7 @@
|
|
151 |
"8b25334a48244a14aa9ba0176887e655": {
|
152 |
"model_module": "@jupyter-widgets/base",
|
153 |
"model_name": "LayoutModel",
|
|
|
154 |
"state": {
|
155 |
"_view_name": "LayoutView",
|
156 |
"grid_template_rows": null,
|
@@ -202,6 +208,7 @@
|
|
202 |
"7e7c488f57fc4acb8d261e2db81d61f0": {
|
203 |
"model_module": "@jupyter-widgets/controls",
|
204 |
"model_name": "DescriptionStyleModel",
|
|
|
205 |
"state": {
|
206 |
"_view_name": "StyleView",
|
207 |
"_model_name": "DescriptionStyleModel",
|
@@ -216,6 +223,7 @@
|
|
216 |
"72c401062a5348b1a366dffb5a403568": {
|
217 |
"model_module": "@jupyter-widgets/base",
|
218 |
"model_name": "LayoutModel",
|
|
|
219 |
"state": {
|
220 |
"_view_name": "LayoutView",
|
221 |
"grid_template_rows": null,
|
@@ -267,6 +275,7 @@
|
|
267 |
"022c124dfff348f285335732781b0887": {
|
268 |
"model_module": "@jupyter-widgets/controls",
|
269 |
"model_name": "HBoxModel",
|
|
|
270 |
"state": {
|
271 |
"_view_name": "HBoxView",
|
272 |
"_dom_classes": [],
|
@@ -287,6 +296,7 @@
|
|
287 |
"a44e47e9d26c4deb81a5a11a9db92a9f": {
|
288 |
"model_module": "@jupyter-widgets/base",
|
289 |
"model_name": "LayoutModel",
|
|
|
290 |
"state": {
|
291 |
"_view_name": "LayoutView",
|
292 |
"grid_template_rows": null,
|
@@ -338,6 +348,7 @@
|
|
338 |
"cd9c7016caae47c1b41fb2608c78b0bf": {
|
339 |
"model_module": "@jupyter-widgets/controls",
|
340 |
"model_name": "FloatProgressModel",
|
|
|
341 |
"state": {
|
342 |
"_view_name": "ProgressView",
|
343 |
"style": "IPY_MODEL_c22f207311cf4fb69bd9328eabfd4ebb",
|
@@ -361,6 +372,7 @@
|
|
361 |
"36ff1d0fea4b47e2ae35aa6bfae6a5e8": {
|
362 |
"model_module": "@jupyter-widgets/controls",
|
363 |
"model_name": "HTMLModel",
|
|
|
364 |
"state": {
|
365 |
"_view_name": "HTMLView",
|
366 |
"style": "IPY_MODEL_037563a7eadd4ac5abb7249a2914d346",
|
@@ -381,6 +393,7 @@
|
|
381 |
"c22f207311cf4fb69bd9328eabfd4ebb": {
|
382 |
"model_module": "@jupyter-widgets/controls",
|
383 |
"model_name": "ProgressStyleModel",
|
|
|
384 |
"state": {
|
385 |
"_view_name": "StyleView",
|
386 |
"_model_name": "ProgressStyleModel",
|
@@ -396,6 +409,7 @@
|
|
396 |
"5a38c6d83a264bedbf7efe6e97eba953": {
|
397 |
"model_module": "@jupyter-widgets/base",
|
398 |
"model_name": "LayoutModel",
|
|
|
399 |
"state": {
|
400 |
"_view_name": "LayoutView",
|
401 |
"grid_template_rows": null,
|
@@ -447,6 +461,7 @@
|
|
447 |
"037563a7eadd4ac5abb7249a2914d346": {
|
448 |
"model_module": "@jupyter-widgets/controls",
|
449 |
"model_name": "DescriptionStyleModel",
|
|
|
450 |
"state": {
|
451 |
"_view_name": "StyleView",
|
452 |
"_model_name": "DescriptionStyleModel",
|
@@ -461,6 +476,7 @@
|
|
461 |
"3975e7ed0b704990b1fa05909a9bb9b6": {
|
462 |
"model_module": "@jupyter-widgets/base",
|
463 |
"model_name": "LayoutModel",
|
|
|
464 |
"state": {
|
465 |
"_view_name": "LayoutView",
|
466 |
"grid_template_rows": null,
|
@@ -512,6 +528,7 @@
|
|
512 |
"f9f1fdc3819a4142b85304cd3c6358a2": {
|
513 |
"model_module": "@jupyter-widgets/controls",
|
514 |
"model_name": "HBoxModel",
|
|
|
515 |
"state": {
|
516 |
"_view_name": "HBoxView",
|
517 |
"_dom_classes": [],
|
@@ -532,6 +549,7 @@
|
|
532 |
"ea9ed54e7c9d4ead8b3e1ff4cb27fa61": {
|
533 |
"model_module": "@jupyter-widgets/base",
|
534 |
"model_name": "LayoutModel",
|
|
|
535 |
"state": {
|
536 |
"_view_name": "LayoutView",
|
537 |
"grid_template_rows": null,
|
@@ -583,6 +601,7 @@
|
|
583 |
"29d42e94b3b34c86a117b623da68faed": {
|
584 |
"model_module": "@jupyter-widgets/controls",
|
585 |
"model_name": "FloatProgressModel",
|
|
|
586 |
"state": {
|
587 |
"_view_name": "ProgressView",
|
588 |
"style": "IPY_MODEL_8ce4d20d004a4382afa0abdd3b1f7191",
|
@@ -606,6 +625,7 @@
|
|
606 |
"8b73de7dbdfe40dbbb39fb593520b984": {
|
607 |
"model_module": "@jupyter-widgets/controls",
|
608 |
"model_name": "HTMLModel",
|
|
|
609 |
"state": {
|
610 |
"_view_name": "HTMLView",
|
611 |
"style": "IPY_MODEL_717ccef4df1f477abb51814650eb47da",
|
@@ -626,6 +646,7 @@
|
|
626 |
"8ce4d20d004a4382afa0abdd3b1f7191": {
|
627 |
"model_module": "@jupyter-widgets/controls",
|
628 |
"model_name": "ProgressStyleModel",
|
|
|
629 |
"state": {
|
630 |
"_view_name": "StyleView",
|
631 |
"_model_name": "ProgressStyleModel",
|
@@ -641,6 +662,7 @@
|
|
641 |
"efc4812245c8459c92e6436889b4f600": {
|
642 |
"model_module": "@jupyter-widgets/base",
|
643 |
"model_name": "LayoutModel",
|
|
|
644 |
"state": {
|
645 |
"_view_name": "LayoutView",
|
646 |
"grid_template_rows": null,
|
@@ -692,6 +714,7 @@
|
|
692 |
"717ccef4df1f477abb51814650eb47da": {
|
693 |
"model_module": "@jupyter-widgets/controls",
|
694 |
"model_name": "DescriptionStyleModel",
|
|
|
695 |
"state": {
|
696 |
"_view_name": "StyleView",
|
697 |
"_model_name": "DescriptionStyleModel",
|
@@ -706,6 +729,7 @@
|
|
706 |
"7dba58f0391c485a86e34e8039ec6189": {
|
707 |
"model_module": "@jupyter-widgets/base",
|
708 |
"model_name": "LayoutModel",
|
|
|
709 |
"state": {
|
710 |
"_view_name": "LayoutView",
|
711 |
"grid_template_rows": null,
|
@@ -804,8 +828,7 @@
|
|
804 |
"source": [
|
805 |
"!pip install -q transformers flax\n",
|
806 |
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git # VQGAN model in JAX\n",
|
807 |
-
"!
|
808 |
-
"%cd dalle-mini/"
|
809 |
],
|
810 |
"execution_count": null,
|
811 |
"outputs": []
|
@@ -833,7 +856,7 @@
|
|
833 |
"import random\n",
|
834 |
"from tqdm.notebook import tqdm, trange"
|
835 |
],
|
836 |
-
"execution_count":
|
837 |
"outputs": []
|
838 |
},
|
839 |
{
|
@@ -846,7 +869,7 @@
|
|
846 |
"DALLE_REPO = 'flax-community/dalle-mini'\n",
|
847 |
"DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'"
|
848 |
],
|
849 |
-
"execution_count":
|
850 |
"outputs": []
|
851 |
},
|
852 |
{
|
@@ -871,7 +894,7 @@
|
|
871 |
"# set a prompt\n",
|
872 |
"prompt = 'picture of a waterfall under the sunset'"
|
873 |
],
|
874 |
-
"execution_count":
|
875 |
"outputs": []
|
876 |
},
|
877 |
{
|
@@ -888,7 +911,7 @@
|
|
888 |
"tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128)\n",
|
889 |
"tokenized_prompt"
|
890 |
],
|
891 |
-
"execution_count":
|
892 |
"outputs": [
|
893 |
{
|
894 |
"output_type": "execute_result",
|
@@ -956,7 +979,7 @@
|
|
956 |
"subkeys = jax.random.split(key, num=n_predictions)\n",
|
957 |
"subkeys"
|
958 |
],
|
959 |
-
"execution_count":
|
960 |
"outputs": [
|
961 |
{
|
962 |
"output_type": "execute_result",
|
@@ -1004,7 +1027,7 @@
|
|
1004 |
"encoded_images = [model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=subkey) for subkey in tqdm(subkeys)]\n",
|
1005 |
"encoded_images[0]"
|
1006 |
],
|
1007 |
-
"execution_count":
|
1008 |
"outputs": [
|
1009 |
{
|
1010 |
"output_type": "display_data",
|
@@ -1099,7 +1122,7 @@
|
|
1099 |
"encoded_images = [img.sequences[..., 1:] for img in encoded_images]\n",
|
1100 |
"encoded_images[0]"
|
1101 |
],
|
1102 |
-
"execution_count":
|
1103 |
"outputs": [
|
1104 |
{
|
1105 |
"output_type": "execute_result",
|
@@ -1167,7 +1190,7 @@
|
|
1167 |
"source": [
|
1168 |
"encoded_images[0].shape"
|
1169 |
],
|
1170 |
-
"execution_count":
|
1171 |
"outputs": [
|
1172 |
{
|
1173 |
"output_type": "execute_result",
|
@@ -1204,7 +1227,7 @@
|
|
1204 |
"import numpy as np\n",
|
1205 |
"from PIL import Image"
|
1206 |
],
|
1207 |
-
"execution_count":
|
1208 |
"outputs": []
|
1209 |
},
|
1210 |
{
|
@@ -1217,7 +1240,7 @@
|
|
1217 |
"VQGAN_REPO = 'flax-community/vqgan_f16_16384'\n",
|
1218 |
"VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'"
|
1219 |
],
|
1220 |
-
"execution_count":
|
1221 |
"outputs": []
|
1222 |
},
|
1223 |
{
|
@@ -1233,7 +1256,7 @@
|
|
1233 |
"# set up VQGAN\n",
|
1234 |
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)"
|
1235 |
],
|
1236 |
-
"execution_count":
|
1237 |
"outputs": [
|
1238 |
{
|
1239 |
"output_type": "stream",
|
@@ -1269,7 +1292,7 @@
|
|
1269 |
"decoded_images = [vqgan.decode_code(encoded_image) for encoded_image in tqdm(encoded_images)]\n",
|
1270 |
"decoded_images[0]"
|
1271 |
],
|
1272 |
-
"execution_count":
|
1273 |
"outputs": [
|
1274 |
{
|
1275 |
"output_type": "display_data",
|
@@ -1373,7 +1396,7 @@
|
|
1373 |
"# normalize images\n",
|
1374 |
"clipped_images = [img.squeeze().clip(0., 1.) for img in decoded_images]"
|
1375 |
],
|
1376 |
-
"execution_count":
|
1377 |
"outputs": []
|
1378 |
},
|
1379 |
{
|
@@ -1385,7 +1408,7 @@
|
|
1385 |
"# convert to image\n",
|
1386 |
"images = [Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in clipped_images]"
|
1387 |
],
|
1388 |
-
"execution_count":
|
1389 |
"outputs": []
|
1390 |
},
|
1391 |
{
|
@@ -1402,7 +1425,7 @@
|
|
1402 |
"# display an image\n",
|
1403 |
"images[0]"
|
1404 |
],
|
1405 |
-
"execution_count":
|
1406 |
"outputs": [
|
1407 |
{
|
1408 |
"output_type": "execute_result",
|
@@ -1438,7 +1461,7 @@
|
|
1438 |
"source": [
|
1439 |
"from transformers import CLIPProcessor, FlaxCLIPModel"
|
1440 |
],
|
1441 |
-
"execution_count":
|
1442 |
"outputs": []
|
1443 |
},
|
1444 |
{
|
@@ -1474,7 +1497,7 @@
|
|
1474 |
"logits = clip(**inputs).logits_per_image\n",
|
1475 |
"scores = jax.nn.softmax(logits, axis=0).squeeze() # normalize and sum all scores to 1"
|
1476 |
],
|
1477 |
-
"execution_count":
|
1478 |
"outputs": []
|
1479 |
},
|
1480 |
{
|
@@ -1495,7 +1518,7 @@
|
|
1495 |
" display(images[idx])\n",
|
1496 |
" print()"
|
1497 |
],
|
1498 |
-
"execution_count":
|
1499 |
"outputs": [
|
1500 |
{
|
1501 |
"output_type": "stream",
|
@@ -1690,7 +1713,7 @@
|
|
1690 |
"from flax.training.common_utils import shard\n",
|
1691 |
"from flax.jax_utils import replicate"
|
1692 |
],
|
1693 |
-
"execution_count":
|
1694 |
"outputs": []
|
1695 |
},
|
1696 |
{
|
@@ -1706,7 +1729,7 @@
|
|
1706 |
"# check we can access TPU's or GPU's\n",
|
1707 |
"jax.devices()"
|
1708 |
],
|
1709 |
-
"execution_count":
|
1710 |
"outputs": [
|
1711 |
{
|
1712 |
"output_type": "execute_result",
|
@@ -1744,7 +1767,7 @@
|
|
1744 |
"# one set of inputs per device\n",
|
1745 |
"prompt = ['picture of a waterfall under the sunset'] * jax.device_count()"
|
1746 |
],
|
1747 |
-
"execution_count":
|
1748 |
"outputs": []
|
1749 |
},
|
1750 |
{
|
@@ -1757,7 +1780,7 @@
|
|
1757 |
"tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
|
1758 |
"tokenized_prompt = shard(tokenized_prompt)"
|
1759 |
],
|
1760 |
-
"execution_count":
|
1761 |
"outputs": []
|
1762 |
},
|
1763 |
{
|
@@ -1793,7 +1816,7 @@
|
|
1793 |
"def p_decode(indices, params):\n",
|
1794 |
" return vqgan.decode_code(indices, params=params)"
|
1795 |
],
|
1796 |
-
"execution_count":
|
1797 |
"outputs": []
|
1798 |
},
|
1799 |
{
|
@@ -1834,7 +1857,7 @@
|
|
1834 |
" for img in decoded_images:\n",
|
1835 |
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
1836 |
],
|
1837 |
-
"execution_count":
|
1838 |
"outputs": [
|
1839 |
{
|
1840 |
"output_type": "display_data",
|
@@ -1877,7 +1900,7 @@
|
|
1877 |
" display(img)\n",
|
1878 |
" print()"
|
1879 |
],
|
1880 |
-
"execution_count":
|
1881 |
"outputs": [
|
1882 |
{
|
1883 |
"output_type": "display_data",
|
|
|
6 |
"name": "DALL·E mini - Inference pipeline.ipynb",
|
7 |
"provenance": [],
|
8 |
"collapsed_sections": [],
|
9 |
+
"authorship_tag": "ABX9TyMUjEt1XMLq+6/GhSnVFsSx",
|
10 |
"include_colab_link": true
|
11 |
},
|
12 |
"kernelspec": {
|
|
|
22 |
"49304912717a4995ae45d04a59d1f50e": {
|
23 |
"model_module": "@jupyter-widgets/controls",
|
24 |
"model_name": "HBoxModel",
|
25 |
+
"model_module_version": "1.5.0",
|
26 |
"state": {
|
27 |
"_view_name": "HBoxView",
|
28 |
"_dom_classes": [],
|
|
|
43 |
"5fd9f97986024e8db560a6737ade9e2e": {
|
44 |
"model_module": "@jupyter-widgets/base",
|
45 |
"model_name": "LayoutModel",
|
46 |
+
"model_module_version": "1.2.0",
|
47 |
"state": {
|
48 |
"_view_name": "LayoutView",
|
49 |
"grid_template_rows": null,
|
|
|
95 |
"caced43e3a4c493b98fb07cb41db045c": {
|
96 |
"model_module": "@jupyter-widgets/controls",
|
97 |
"model_name": "FloatProgressModel",
|
98 |
+
"model_module_version": "1.5.0",
|
99 |
"state": {
|
100 |
"_view_name": "ProgressView",
|
101 |
"style": "IPY_MODEL_40c54b9454d346aabd197f2bcf189467",
|
|
|
119 |
"0acc161f2e9948b68b3fc4e57ef333c9": {
|
120 |
"model_module": "@jupyter-widgets/controls",
|
121 |
"model_name": "HTMLModel",
|
122 |
+
"model_module_version": "1.5.0",
|
123 |
"state": {
|
124 |
"_view_name": "HTMLView",
|
125 |
"style": "IPY_MODEL_7e7c488f57fc4acb8d261e2db81d61f0",
|
|
|
140 |
"40c54b9454d346aabd197f2bcf189467": {
|
141 |
"model_module": "@jupyter-widgets/controls",
|
142 |
"model_name": "ProgressStyleModel",
|
143 |
+
"model_module_version": "1.5.0",
|
144 |
"state": {
|
145 |
"_view_name": "StyleView",
|
146 |
"_model_name": "ProgressStyleModel",
|
|
|
156 |
"8b25334a48244a14aa9ba0176887e655": {
|
157 |
"model_module": "@jupyter-widgets/base",
|
158 |
"model_name": "LayoutModel",
|
159 |
+
"model_module_version": "1.2.0",
|
160 |
"state": {
|
161 |
"_view_name": "LayoutView",
|
162 |
"grid_template_rows": null,
|
|
|
208 |
"7e7c488f57fc4acb8d261e2db81d61f0": {
|
209 |
"model_module": "@jupyter-widgets/controls",
|
210 |
"model_name": "DescriptionStyleModel",
|
211 |
+
"model_module_version": "1.5.0",
|
212 |
"state": {
|
213 |
"_view_name": "StyleView",
|
214 |
"_model_name": "DescriptionStyleModel",
|
|
|
223 |
"72c401062a5348b1a366dffb5a403568": {
|
224 |
"model_module": "@jupyter-widgets/base",
|
225 |
"model_name": "LayoutModel",
|
226 |
+
"model_module_version": "1.2.0",
|
227 |
"state": {
|
228 |
"_view_name": "LayoutView",
|
229 |
"grid_template_rows": null,
|
|
|
275 |
"022c124dfff348f285335732781b0887": {
|
276 |
"model_module": "@jupyter-widgets/controls",
|
277 |
"model_name": "HBoxModel",
|
278 |
+
"model_module_version": "1.5.0",
|
279 |
"state": {
|
280 |
"_view_name": "HBoxView",
|
281 |
"_dom_classes": [],
|
|
|
296 |
"a44e47e9d26c4deb81a5a11a9db92a9f": {
|
297 |
"model_module": "@jupyter-widgets/base",
|
298 |
"model_name": "LayoutModel",
|
299 |
+
"model_module_version": "1.2.0",
|
300 |
"state": {
|
301 |
"_view_name": "LayoutView",
|
302 |
"grid_template_rows": null,
|
|
|
348 |
"cd9c7016caae47c1b41fb2608c78b0bf": {
|
349 |
"model_module": "@jupyter-widgets/controls",
|
350 |
"model_name": "FloatProgressModel",
|
351 |
+
"model_module_version": "1.5.0",
|
352 |
"state": {
|
353 |
"_view_name": "ProgressView",
|
354 |
"style": "IPY_MODEL_c22f207311cf4fb69bd9328eabfd4ebb",
|
|
|
372 |
"36ff1d0fea4b47e2ae35aa6bfae6a5e8": {
|
373 |
"model_module": "@jupyter-widgets/controls",
|
374 |
"model_name": "HTMLModel",
|
375 |
+
"model_module_version": "1.5.0",
|
376 |
"state": {
|
377 |
"_view_name": "HTMLView",
|
378 |
"style": "IPY_MODEL_037563a7eadd4ac5abb7249a2914d346",
|
|
|
393 |
"c22f207311cf4fb69bd9328eabfd4ebb": {
|
394 |
"model_module": "@jupyter-widgets/controls",
|
395 |
"model_name": "ProgressStyleModel",
|
396 |
+
"model_module_version": "1.5.0",
|
397 |
"state": {
|
398 |
"_view_name": "StyleView",
|
399 |
"_model_name": "ProgressStyleModel",
|
|
|
409 |
"5a38c6d83a264bedbf7efe6e97eba953": {
|
410 |
"model_module": "@jupyter-widgets/base",
|
411 |
"model_name": "LayoutModel",
|
412 |
+
"model_module_version": "1.2.0",
|
413 |
"state": {
|
414 |
"_view_name": "LayoutView",
|
415 |
"grid_template_rows": null,
|
|
|
461 |
"037563a7eadd4ac5abb7249a2914d346": {
|
462 |
"model_module": "@jupyter-widgets/controls",
|
463 |
"model_name": "DescriptionStyleModel",
|
464 |
+
"model_module_version": "1.5.0",
|
465 |
"state": {
|
466 |
"_view_name": "StyleView",
|
467 |
"_model_name": "DescriptionStyleModel",
|
|
|
476 |
"3975e7ed0b704990b1fa05909a9bb9b6": {
|
477 |
"model_module": "@jupyter-widgets/base",
|
478 |
"model_name": "LayoutModel",
|
479 |
+
"model_module_version": "1.2.0",
|
480 |
"state": {
|
481 |
"_view_name": "LayoutView",
|
482 |
"grid_template_rows": null,
|
|
|
528 |
"f9f1fdc3819a4142b85304cd3c6358a2": {
|
529 |
"model_module": "@jupyter-widgets/controls",
|
530 |
"model_name": "HBoxModel",
|
531 |
+
"model_module_version": "1.5.0",
|
532 |
"state": {
|
533 |
"_view_name": "HBoxView",
|
534 |
"_dom_classes": [],
|
|
|
549 |
"ea9ed54e7c9d4ead8b3e1ff4cb27fa61": {
|
550 |
"model_module": "@jupyter-widgets/base",
|
551 |
"model_name": "LayoutModel",
|
552 |
+
"model_module_version": "1.2.0",
|
553 |
"state": {
|
554 |
"_view_name": "LayoutView",
|
555 |
"grid_template_rows": null,
|
|
|
601 |
"29d42e94b3b34c86a117b623da68faed": {
|
602 |
"model_module": "@jupyter-widgets/controls",
|
603 |
"model_name": "FloatProgressModel",
|
604 |
+
"model_module_version": "1.5.0",
|
605 |
"state": {
|
606 |
"_view_name": "ProgressView",
|
607 |
"style": "IPY_MODEL_8ce4d20d004a4382afa0abdd3b1f7191",
|
|
|
625 |
"8b73de7dbdfe40dbbb39fb593520b984": {
|
626 |
"model_module": "@jupyter-widgets/controls",
|
627 |
"model_name": "HTMLModel",
|
628 |
+
"model_module_version": "1.5.0",
|
629 |
"state": {
|
630 |
"_view_name": "HTMLView",
|
631 |
"style": "IPY_MODEL_717ccef4df1f477abb51814650eb47da",
|
|
|
646 |
"8ce4d20d004a4382afa0abdd3b1f7191": {
|
647 |
"model_module": "@jupyter-widgets/controls",
|
648 |
"model_name": "ProgressStyleModel",
|
649 |
+
"model_module_version": "1.5.0",
|
650 |
"state": {
|
651 |
"_view_name": "StyleView",
|
652 |
"_model_name": "ProgressStyleModel",
|
|
|
662 |
"efc4812245c8459c92e6436889b4f600": {
|
663 |
"model_module": "@jupyter-widgets/base",
|
664 |
"model_name": "LayoutModel",
|
665 |
+
"model_module_version": "1.2.0",
|
666 |
"state": {
|
667 |
"_view_name": "LayoutView",
|
668 |
"grid_template_rows": null,
|
|
|
714 |
"717ccef4df1f477abb51814650eb47da": {
|
715 |
"model_module": "@jupyter-widgets/controls",
|
716 |
"model_name": "DescriptionStyleModel",
|
717 |
+
"model_module_version": "1.5.0",
|
718 |
"state": {
|
719 |
"_view_name": "StyleView",
|
720 |
"_model_name": "DescriptionStyleModel",
|
|
|
729 |
"7dba58f0391c485a86e34e8039ec6189": {
|
730 |
"model_module": "@jupyter-widgets/base",
|
731 |
"model_name": "LayoutModel",
|
732 |
+
"model_module_version": "1.2.0",
|
733 |
"state": {
|
734 |
"_view_name": "LayoutView",
|
735 |
"grid_template_rows": null,
|
|
|
828 |
"source": [
|
829 |
"!pip install -q transformers flax\n",
|
830 |
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git # VQGAN model in JAX\n",
|
831 |
+
"!pip install -q git+https://github.com/borisdayma/dalle-mini.git # Model files"
|
|
|
832 |
],
|
833 |
"execution_count": null,
|
834 |
"outputs": []
|
|
|
856 |
"import random\n",
|
857 |
"from tqdm.notebook import tqdm, trange"
|
858 |
],
|
859 |
+
"execution_count": null,
|
860 |
"outputs": []
|
861 |
},
|
862 |
{
|
|
|
869 |
"DALLE_REPO = 'flax-community/dalle-mini'\n",
|
870 |
"DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'"
|
871 |
],
|
872 |
+
"execution_count": null,
|
873 |
"outputs": []
|
874 |
},
|
875 |
{
|
|
|
894 |
"# set a prompt\n",
|
895 |
"prompt = 'picture of a waterfall under the sunset'"
|
896 |
],
|
897 |
+
"execution_count": null,
|
898 |
"outputs": []
|
899 |
},
|
900 |
{
|
|
|
911 |
"tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128)\n",
|
912 |
"tokenized_prompt"
|
913 |
],
|
914 |
+
"execution_count": null,
|
915 |
"outputs": [
|
916 |
{
|
917 |
"output_type": "execute_result",
|
|
|
979 |
"subkeys = jax.random.split(key, num=n_predictions)\n",
|
980 |
"subkeys"
|
981 |
],
|
982 |
+
"execution_count": null,
|
983 |
"outputs": [
|
984 |
{
|
985 |
"output_type": "execute_result",
|
|
|
1027 |
"encoded_images = [model.generate(**tokenized_prompt, do_sample=True, num_beams=1, prng_key=subkey) for subkey in tqdm(subkeys)]\n",
|
1028 |
"encoded_images[0]"
|
1029 |
],
|
1030 |
+
"execution_count": null,
|
1031 |
"outputs": [
|
1032 |
{
|
1033 |
"output_type": "display_data",
|
|
|
1122 |
"encoded_images = [img.sequences[..., 1:] for img in encoded_images]\n",
|
1123 |
"encoded_images[0]"
|
1124 |
],
|
1125 |
+
"execution_count": null,
|
1126 |
"outputs": [
|
1127 |
{
|
1128 |
"output_type": "execute_result",
|
|
|
1190 |
"source": [
|
1191 |
"encoded_images[0].shape"
|
1192 |
],
|
1193 |
+
"execution_count": null,
|
1194 |
"outputs": [
|
1195 |
{
|
1196 |
"output_type": "execute_result",
|
|
|
1227 |
"import numpy as np\n",
|
1228 |
"from PIL import Image"
|
1229 |
],
|
1230 |
+
"execution_count": null,
|
1231 |
"outputs": []
|
1232 |
},
|
1233 |
{
|
|
|
1240 |
"VQGAN_REPO = 'flax-community/vqgan_f16_16384'\n",
|
1241 |
"VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'"
|
1242 |
],
|
1243 |
+
"execution_count": null,
|
1244 |
"outputs": []
|
1245 |
},
|
1246 |
{
|
|
|
1256 |
"# set up VQGAN\n",
|
1257 |
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)"
|
1258 |
],
|
1259 |
+
"execution_count": null,
|
1260 |
"outputs": [
|
1261 |
{
|
1262 |
"output_type": "stream",
|
|
|
1292 |
"decoded_images = [vqgan.decode_code(encoded_image) for encoded_image in tqdm(encoded_images)]\n",
|
1293 |
"decoded_images[0]"
|
1294 |
],
|
1295 |
+
"execution_count": null,
|
1296 |
"outputs": [
|
1297 |
{
|
1298 |
"output_type": "display_data",
|
|
|
1396 |
"# normalize images\n",
|
1397 |
"clipped_images = [img.squeeze().clip(0., 1.) for img in decoded_images]"
|
1398 |
],
|
1399 |
+
"execution_count": null,
|
1400 |
"outputs": []
|
1401 |
},
|
1402 |
{
|
|
|
1408 |
"# convert to image\n",
|
1409 |
"images = [Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in clipped_images]"
|
1410 |
],
|
1411 |
+
"execution_count": null,
|
1412 |
"outputs": []
|
1413 |
},
|
1414 |
{
|
|
|
1425 |
"# display an image\n",
|
1426 |
"images[0]"
|
1427 |
],
|
1428 |
+
"execution_count": null,
|
1429 |
"outputs": [
|
1430 |
{
|
1431 |
"output_type": "execute_result",
|
|
|
1461 |
"source": [
|
1462 |
"from transformers import CLIPProcessor, FlaxCLIPModel"
|
1463 |
],
|
1464 |
+
"execution_count": null,
|
1465 |
"outputs": []
|
1466 |
},
|
1467 |
{
|
|
|
1497 |
"logits = clip(**inputs).logits_per_image\n",
|
1498 |
"scores = jax.nn.softmax(logits, axis=0).squeeze() # normalize and sum all scores to 1"
|
1499 |
],
|
1500 |
+
"execution_count": null,
|
1501 |
"outputs": []
|
1502 |
},
|
1503 |
{
|
|
|
1518 |
" display(images[idx])\n",
|
1519 |
" print()"
|
1520 |
],
|
1521 |
+
"execution_count": null,
|
1522 |
"outputs": [
|
1523 |
{
|
1524 |
"output_type": "stream",
|
|
|
1713 |
"from flax.training.common_utils import shard\n",
|
1714 |
"from flax.jax_utils import replicate"
|
1715 |
],
|
1716 |
+
"execution_count": null,
|
1717 |
"outputs": []
|
1718 |
},
|
1719 |
{
|
|
|
1729 |
"# check we can access TPU's or GPU's\n",
|
1730 |
"jax.devices()"
|
1731 |
],
|
1732 |
+
"execution_count": null,
|
1733 |
"outputs": [
|
1734 |
{
|
1735 |
"output_type": "execute_result",
|
|
|
1767 |
"# one set of inputs per device\n",
|
1768 |
"prompt = ['picture of a waterfall under the sunset'] * jax.device_count()"
|
1769 |
],
|
1770 |
+
"execution_count": null,
|
1771 |
"outputs": []
|
1772 |
},
|
1773 |
{
|
|
|
1780 |
"tokenized_prompt = tokenizer(prompt, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
|
1781 |
"tokenized_prompt = shard(tokenized_prompt)"
|
1782 |
],
|
1783 |
+
"execution_count": null,
|
1784 |
"outputs": []
|
1785 |
},
|
1786 |
{
|
|
|
1816 |
"def p_decode(indices, params):\n",
|
1817 |
" return vqgan.decode_code(indices, params=params)"
|
1818 |
],
|
1819 |
+
"execution_count": null,
|
1820 |
"outputs": []
|
1821 |
},
|
1822 |
{
|
|
|
1857 |
" for img in decoded_images:\n",
|
1858 |
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
1859 |
],
|
1860 |
+
"execution_count": null,
|
1861 |
"outputs": [
|
1862 |
{
|
1863 |
"output_type": "display_data",
|
|
|
1900 |
" display(img)\n",
|
1901 |
" print()"
|
1902 |
],
|
1903 |
+
"execution_count": null,
|
1904 |
"outputs": [
|
1905 |
{
|
1906 |
"output_type": "display_data",
|
dev/requirements.txt
CHANGED
@@ -1,10 +1,8 @@
|
|
1 |
-
# Note: install with the following command:
|
2 |
-
# pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
3 |
-
# Otherwise it won't find the appropriate libtpu_nightly
|
4 |
requests
|
|
|
5 |
jax[tpu]>=0.2.16
|
6 |
-
|
7 |
-
|
8 |
flax
|
9 |
jupyter
|
10 |
wandb
|
|
|
|
|
|
|
|
|
1 |
requests
|
2 |
+
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
3 |
jax[tpu]>=0.2.16
|
4 |
+
transformers
|
5 |
+
datasets
|
6 |
flax
|
7 |
jupyter
|
8 |
wandb
|
requirements.txt
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
# Requirements for huggingface spaces
|
2 |
-
streamlit>=0.84.2
|
|
|
|
|
|
setup.cfg
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[metadata]
|
2 |
+
name = dalle_mini
|
3 |
+
version = attr: dalle_mini.__version__
|
4 |
+
description = DALL·E mini - Generate images from a text prompt
|
5 |
+
long_description = file: README.md
|
6 |
+
long_description_content_type = text/markdown
|
7 |
+
url = https://github.com/borisdayma/dalle-mini
|
8 |
+
project_urls =
|
9 |
+
Bug Tracker = https://github.com/borisdayma/dalle-mini/issues
|
10 |
+
|
11 |
+
[options]
|
12 |
+
packages = find:
|
13 |
+
install_requires =
|
14 |
+
transformers
|
15 |
+
jax
|
16 |
+
flax
|
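Note that `version = attr: dalle_mini.__version__` makes setuptools read the version from the package itself, so `dalle_mini/__init__.py` is expected to define it, e.g. (illustrative value only):

# dalle_mini/__init__.py
__version__ = "0.0.1"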
setup.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
setup()
|