boris committed
Commit 24d30c9
2 Parent(s): 6016fc0 74974be

Merge pull request #108 from borisdayma/feat-inf

app/app.py CHANGED
@@ -2,31 +2,10 @@
2
  # coding: utf-8
3
 
4
  from dalle_mini.backend import ServiceError, get_images_from_backend
5
-
6
  import streamlit as st
7
 
8
- # streamlit.session_state is not available in Huggingface spaces.
9
- # Session state hack https://huggingface.slack.com/archives/C025LJDP962/p1626527367443200?thread_ts=1626525999.440500&cid=C025LJDP962
10
-
11
- from streamlit.report_thread import get_report_ctx
12
- def query_cache(q_emb=None):
13
- ctx = get_report_ctx()
14
- session_id = ctx.session_id
15
- session = st.server.server.Server.get_current()._get_session_info(session_id).session
16
- if not hasattr(session, "_query_state"):
17
- setattr(session, "_query_state", q_emb)
18
- if q_emb:
19
- session._query_state = q_emb
20
- return session._query_state
21
-
22
- def set_run_again(state):
23
- query_cache(state)
24
-
25
- def should_run_again():
26
- state = query_cache()
27
- return state if state is not None else False
28
-
29
- st.sidebar.markdown("""
30
  <style>
31
  .aligncenter {
32
  text-align: center;
@@ -35,8 +14,11 @@ st.sidebar.markdown("""
35
  <p class="aligncenter">
36
  <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/img/logo.png"/>
37
  </p>
38
- """, unsafe_allow_html=True)
39
- st.sidebar.markdown("""
 
 
 
40
  ___
41
  <p style='text-align: center'>
42
  DALL·E mini is an AI model that generates images from any prompt you give!
@@ -47,21 +29,20 @@ Created by Boris Dayma et al. 2021
47
  <br/>
48
  <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
49
  </p>
50
- """, unsafe_allow_html=True)
 
 
51
 
52
- st.header('DALL·E mini')
53
- st.subheader('Generate images from text')
54
 
55
  prompt = st.text_input("What do you want to see?")
56
 
57
- test = st.empty()
58
  DEBUG = False
59
- if prompt != "" or (should_run_again and prompt != ""):
60
  container = st.empty()
61
- # The following mimics `streamlit.info()`.
62
- # I tried to get the secondary background color using `components.streamlit.config.get_options_for_section("theme")["secondaryBackgroundColor"]`
63
- # but it returns None.
64
- container.markdown(f"""
65
  <style> p {{ margin:0 }} div {{ margin:0 }} </style>
66
  <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
67
  <div class="stAlert">
@@ -78,32 +59,39 @@ if prompt != "" or (should_run_again and prompt != ""):
78
  </div>
79
  </div>
80
  <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
81
- """, unsafe_allow_html=True)
 
 
82
 
83
  try:
84
  backend_url = st.secrets["BACKEND_SERVER"]
85
  print(f"Getting selections: {prompt}")
86
  selected = get_images_from_backend(prompt, backend_url)
87
 
88
- cols = st.columns(4)
 
 
89
  for i, img in enumerate(selected):
90
- cols[i%4].image(img)
91
-
92
  container.markdown(f"**{prompt}**")
93
-
94
- set_run_again(st.button('Again!', key='again_button'))
95
-
96
  except ServiceError as error:
97
  container.text(f"Service unavailable, status: {error.status_code}")
98
  except KeyError:
99
  if DEBUG:
100
- container.markdown("""
 
101
  **Error: BACKEND_SERVER unset**
102
 
103
  Please, create a file called `.streamlit/secrets.toml` inside the app's folder and include a line to configure the server URL:
104
  ```
105
  BACKEND_SERVER="<server url>"
106
  ```
107
- """)
 
108
  else:
109
- container.markdown('Error -5, please try again or [report it](mailto:[email protected]).')
 
 
 
2
  # coding: utf-8
3
 
4
  from dalle_mini.backend import ServiceError, get_images_from_backend
 
5
  import streamlit as st
6
 
7
+ st.sidebar.markdown(
8
+ """
9
  <style>
10
  .aligncenter {
11
  text-align: center;
 
14
  <p class="aligncenter">
15
  <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/img/logo.png"/>
16
  </p>
17
+ """,
18
+ unsafe_allow_html=True,
19
+ )
20
+ st.sidebar.markdown(
21
+ """
22
  ___
23
  <p style='text-align: center'>
24
  DALL·E mini is an AI model that generates images from any prompt you give!
 
29
  <br/>
30
  <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
31
  </p>
32
+ """,
33
+ unsafe_allow_html=True,
34
+ )
35
 
36
+ st.header("DALL·E mini")
37
+ st.subheader("Generate images from text")
38
 
39
  prompt = st.text_input("What do you want to see?")
40
 
 
41
  DEBUG = False
42
+ if prompt != "":
43
  container = st.empty()
44
+ container.markdown(
45
+ f"""
 
 
46
  <style> p {{ margin:0 }} div {{ margin:0 }} </style>
47
  <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
48
  <div class="stAlert">
 
59
  </div>
60
  </div>
61
  <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
62
+ """,
63
+ unsafe_allow_html=True,
64
+ )
65
 
66
  try:
67
  backend_url = st.secrets["BACKEND_SERVER"]
68
  print(f"Getting selections: {prompt}")
69
  selected = get_images_from_backend(prompt, backend_url)
70
 
71
+ margin = 0.1 # for better position of zoom in arrow
72
+ n_columns = 3
73
+ cols = st.columns([1] + [margin, 1] * (n_columns - 1))
74
  for i, img in enumerate(selected):
75
+ cols[(i % n_columns) * 2].image(img)
 
76
  container.markdown(f"**{prompt}**")
77
+
78
+ st.button("Again!", key="again_button")
79
+
80
  except ServiceError as error:
81
  container.text(f"Service unavailable, status: {error.status_code}")
82
  except KeyError:
83
  if DEBUG:
84
+ container.markdown(
85
+ """
86
  **Error: BACKEND_SERVER unset**
87
 
88
  Please, create a file called `.streamlit/secrets.toml` inside the app's folder and include a line to configure the server URL:
89
  ```
90
  BACKEND_SERVER="<server url>"
91
  ```
92
+ """
93
+ )
94
  else:
95
+ container.markdown(
96
+ "Error -5, please try again or [report it](mailto:[email protected])."
97
+ )
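For reference, the gallery layout used above places a narrow spacer column between every pair of image columns so that Streamlit's zoom-in arrow does not overlap the neighbouring image, and the old `streamlit.report_thread` session-state hack is gone entirely (the `Again!` button now just relies on Streamlit's normal rerun). A minimal standalone sketch of the spacer-column pattern, with hypothetical placeholder image paths:

```python
import streamlit as st

# Column widths become [1, 0.1, 1, 0.1, 1]: image columns sit at even indices,
# narrow spacer columns at odd indices (margin value taken from the diff above).
margin = 0.1
n_columns = 3
cols = st.columns([1] + [margin, 1] * (n_columns - 1))

images = ["img_0.png", "img_1.png", "img_2.png", "img_3.png"]  # placeholder paths
for i, img in enumerate(images):
    # map image i to column index (i % n_columns) * 2 so the spacer columns stay empty
    cols[(i % n_columns) * 2].image(img)
```

With `n_columns = 3`, the fourth image wraps back to the first image column on a new row.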
dev/inference/README.md DELETED
@@ -1 +0,0 @@
1
- Scripts to generate predictions for assessment and reporting.
 
 
dev/inference/wandb-examples-from-backend.py DELETED
@@ -1,76 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- from PIL import Image, ImageDraw, ImageFont
5
- import wandb
6
- import os
7
-
8
- from dalle_mini.backend import ServiceError, get_images_from_backend
9
- from dalle_mini.helpers import captioned_strip
10
-
11
- os.environ["WANDB_SILENT"] = "true"
12
- os.environ["WANDB_CONSOLE"] = "off"
13
-
14
- def log_to_wandb(prompts):
15
- try:
16
- backend_url = os.environ["BACKEND_SERVER"]
17
- for _ in range(1):
18
- for prompt in prompts:
19
- print(f"Getting selections for: {prompt}")
20
- # make a separate run per prompt
21
- with wandb.init(
22
- entity='wandb',
23
- project='hf-flax-dalle-mini',
24
- job_type='predictions',# tags=['openai'],
25
- config={'prompt': prompt}
26
- ):
27
- imgs = []
28
- selected = get_images_from_backend(prompt, backend_url)
29
- strip = captioned_strip(selected, prompt)
30
- imgs.append(wandb.Image(strip))
31
- wandb.log({"images": imgs})
32
- except ServiceError as error:
33
- print(f"Service unavailable, status: {error.status_code}")
34
- except KeyError:
35
- print("Error: BACKEND_SERVER unset")
36
-
37
- prompts = [
38
- # "white snow covered mountain under blue sky during daytime",
39
- # "aerial view of beach during daytime",
40
- # "aerial view of beach at night",
41
- # "a farmhouse surrounded by beautiful flowers",
42
- # "an armchair in the shape of an avocado",
43
- # "young woman riding her bike trough a forest",
44
- # "a unicorn is passing by a rainbow in a field of flowers",
45
- # "illustration of a baby shark swimming around corals",
46
- # "painting of an oniric forest glade surrounded by tall trees",
47
- # "sunset over green mountains",
48
- # "a forest glade surrounded by tall trees in a sunny Spring morning",
49
- # "fishing village under the moonlight in a serene sunset",
50
- # "cartoon of a carrot with big eyes",
51
- # "still life in the style of Kandinsky",
52
- # "still life in the style of Picasso",
53
- # "a graphite sketch of a gothic cathedral",
54
- # "a graphite sketch of Elon Musk",
55
- # "a watercolor pond with green leaves and yellow flowers",
56
- # "a logo of a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps",
57
- # "happy celebration in a small village in Africa",
58
- # "a logo of an armchair in the shape of an avocado"
59
- # "Pele and Maradona in a hypothetical match",
60
- # "Mohammed Ali and Mike Tyson in a hypothetical match",
61
- # "a storefront that has the word 'openai' written on it",
62
- # "a pentagonal green clock",
63
- # "a collection of glasses is sitting on a table",
64
- # "a small red block sitting on a large green block",
65
- # "an extreme close-up view of a capybara sitting in a field",
66
- # "a cross-section view of a walnut",
67
- # "a professional high-quality emoji of a lovestruck cup of boba",
68
- # "a photo of san francisco's golden gate bridge",
69
- # "an illustration of a baby daikon radish in a tutu walking a dog",
70
- # "a picture of the Eiffel tower on the Moon",
71
- # "a colorful stairway to heaven",
72
- "this is a detailed high-resolution scan of a human brain"
73
- ]
74
-
75
- for _ in range(1):
76
- log_to_wandb(prompts)
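For context, the deleted script's core loop creates one wandb run per prompt and logs a captioned strip of the images returned by the inference backend. A condensed sketch under the same `BACKEND_SERVER` environment variable and wandb entity/project as above (the `log_prompt` helper name is illustrative, not from the repo):

```python
import os

import wandb
from dalle_mini.backend import ServiceError, get_images_from_backend
from dalle_mini.helpers import captioned_strip


def log_prompt(prompt):
    """Log one wandb run containing a captioned strip of backend results for `prompt`."""
    backend_url = os.environ["BACKEND_SERVER"]
    with wandb.init(
        entity="wandb",
        project="hf-flax-dalle-mini",
        job_type="predictions",
        config={"prompt": prompt},
    ):
        try:
            selected = get_images_from_backend(prompt, backend_url)
            strip = captioned_strip(selected, prompt)
            wandb.log({"images": [wandb.Image(strip)]})
        except ServiceError as error:
            print(f"Service unavailable, status: {error.status_code}")
```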
dev/inference/wandb-examples.py DELETED
@@ -1,163 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- import random
5
-
6
- import jax
7
- from flax.training.common_utils import shard
8
- from flax.jax_utils import replicate, unreplicate
9
-
10
- from transformers.models.bart.modeling_flax_bart import *
11
- from transformers import BartTokenizer, FlaxBartForConditionalGeneration
12
-
13
- import os
14
-
15
- from PIL import Image
16
- import numpy as np
17
- import matplotlib.pyplot as plt
18
-
19
- import torch
20
- import torchvision.transforms as T
21
- import torchvision.transforms.functional as TF
22
- from torchvision.transforms import InterpolationMode
23
-
24
- from dalle_mini.model import CustomFlaxBartForConditionalGeneration
25
- from vqgan_jax.modeling_flax_vqgan import VQModel
26
-
27
- # ## CLIP Scoring
28
- from transformers import CLIPProcessor, FlaxCLIPModel
29
-
30
- import wandb
31
- import os
32
-
33
- from dalle_mini.helpers import captioned_strip
34
-
35
-
36
- os.environ["WANDB_SILENT"] = "true"
37
- os.environ["WANDB_CONSOLE"] = "off"
38
-
39
- # TODO: used for legacy support
40
- BASE_MODEL = 'facebook/bart-large-cnn'
41
-
42
- # set id to None so our latest images don't get overwritten
43
- id = None
44
- run = wandb.init(id=id,
45
- entity='wandb',
46
- project="hf-flax-dalle-mini",
47
- job_type="predictions",
48
- resume="allow"
49
- )
50
- artifact = run.use_artifact('wandb/hf-flax-dalle-mini/model-4oh3u7ca:latest', type='bart_model')
51
- artifact_dir = artifact.download()
52
-
53
- # create our model
54
- model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
55
-
56
- # TODO: legacy support (earlier models)
57
- tokenizer = BartTokenizer.from_pretrained(BASE_MODEL)
58
- model.config.force_bos_token_to_be_generated = False
59
- model.config.forced_bos_token_id = None
60
- model.config.forced_eos_token_id = None
61
-
62
- vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
63
-
64
- def custom_to_pil(x):
65
- x = np.clip(x, 0., 1.)
66
- x = (255*x).astype(np.uint8)
67
- x = Image.fromarray(x)
68
- if not x.mode == "RGB":
69
- x = x.convert("RGB")
70
- return x
71
-
72
- def generate(input, rng, params):
73
- return model.generate(
74
- **input,
75
- max_length=257,
76
- num_beams=1,
77
- do_sample=True,
78
- prng_key=rng,
79
- eos_token_id=50000,
80
- pad_token_id=50000,
81
- params=params,
82
- )
83
-
84
- def get_images(indices, params):
85
- return vqgan.decode_code(indices, params=params)
86
-
87
- def plot_images(images):
88
- fig = plt.figure(figsize=(40, 20))
89
- columns = 4
90
- rows = 2
91
- plt.subplots_adjust(hspace=0, wspace=0)
92
-
93
- for i in range(1, columns*rows +1):
94
- fig.add_subplot(rows, columns, i)
95
- plt.imshow(images[i-1])
96
- plt.gca().axes.get_yaxis().set_visible(False)
97
- plt.show()
98
-
99
- def stack_reconstructions(images):
100
- w, h = images[0].size[0], images[0].size[1]
101
- img = Image.new("RGB", (len(images)*w, h))
102
- for i, img_ in enumerate(images):
103
- img.paste(img_, (i*w,0))
104
- return img
105
-
106
- p_generate = jax.pmap(generate, "batch")
107
- p_get_images = jax.pmap(get_images, "batch")
108
-
109
- bart_params = replicate(model.params)
110
- vqgan_params = replicate(vqgan.params)
111
-
112
- clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
113
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
114
-
115
- def hallucinate(prompt, num_images=64):
116
- prompt = [prompt] * jax.device_count()
117
- inputs = tokenizer(prompt, return_tensors='jax', padding="max_length", truncation=True, max_length=128).data
118
- inputs = shard(inputs)
119
-
120
- all_images = []
121
- for i in range(num_images // jax.device_count()):
122
- key = random.randint(0, 1e7)
123
- rng = jax.random.PRNGKey(key)
124
- rngs = jax.random.split(rng, jax.local_device_count())
125
- indices = p_generate(inputs, rngs, bart_params).sequences
126
- indices = indices[:, :, 1:]
127
-
128
- images = p_get_images(indices, vqgan_params)
129
- images = np.squeeze(np.asarray(images), 1)
130
- for image in images:
131
- all_images.append(custom_to_pil(image))
132
- return all_images
133
-
134
- def clip_top_k(prompt, images, k=8):
135
- inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
136
- # FIXME: image should be resized and normalized prior to being processed by CLIP
137
- outputs = clip(**inputs)
138
- logits = outputs.logits_per_text
139
- scores = np.array(logits[0]).argsort()[-k:][::-1]
140
- return [images[score] for score in scores]
141
-
142
- def log_to_wandb(prompts):
143
- strips = []
144
- for prompt in prompts:
145
- print(f"Generating candidates for: {prompt}")
146
- images = hallucinate(prompt, num_images=32)
147
- selected = clip_top_k(prompt, images, k=8)
148
- strip = captioned_strip(selected, prompt)
149
- strips.append(wandb.Image(strip))
150
- wandb.log({"images": strips})
151
-
152
- prompts = prompts = [
153
- "white snow covered mountain under blue sky during daytime",
154
- "aerial view of beach during daytime",
155
- "aerial view of beach at night",
156
- "an armchair in the shape of an avocado",
157
- "young woman riding her bike trough a forest",
158
- "rice fields by the mediterranean coast",
159
- "white houses on the hill of a greek coastline",
160
- "illustration of a shark with a baby shark",
161
- ]
162
-
163
- log_to_wandb(prompts)
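The CLIP reranking in this deleted script (and, refactored, in the notebook below) amounts to scoring every candidate image against the prompt and keeping the top-k by `logits_per_text`. A condensed sketch with the same `openai/clip-vit-base-patch32` checkpoint:

```python
import numpy as np
from transformers import CLIPProcessor, FlaxCLIPModel

clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


def clip_top_k(prompt, images, k=8):
    """Return the k PIL images whose CLIP score against `prompt` is highest."""
    inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
    logits = clip(**inputs).logits_per_text  # shape: (1, num_images)
    scores = np.array(logits[0]).argsort()[-k:][::-1]
    return [images[i] for i in scores]
```

As the FIXME in the deleted script notes, images should ideally be resized and normalized before being handed to CLIP.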
{dev → tools}/inference/inference_pipeline.ipynb RENAMED
File without changes
dev/inference/wandb-backend.ipynb → tools/inference/log_inference_samples.ipynb RENAMED
@@ -24,25 +24,6 @@
24
  "from dalle_mini.text import TextNormalizer"
25
  ]
26
  },
27
- {
28
- "cell_type": "code",
29
- "execution_count": null,
30
- "id": "23e00271-941c-4e1b-b6a9-107a1b77324d",
31
- "metadata": {},
32
- "outputs": [],
33
- "source": [
34
- "run_ids = ['3kaut6e8']\n",
35
- "# Alamy - 3kaut6e8\n",
36
- "# YFCC - to do\n",
37
- "# HF spaces - 4oh3u7ca\n",
38
- "ENTITY, PROJECT = 'wandb', 'hf-flax-dalle-mini'\n",
39
- "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
40
- "normalize_text = False\n",
41
- "latest_only = True # log only latest or all versions\n",
42
- "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
43
- "add_clip_32 = False"
44
- ]
45
- },
46
  {
47
  "cell_type": "code",
48
  "execution_count": null,
@@ -50,13 +31,9 @@
50
  "metadata": {},
51
  "outputs": [],
52
  "source": [
53
- "run_ids = ['2u5lk3uw']\n",
54
- "# poorly shuffled 1nj161cl\n",
55
- "# well shuffled he9rrc3q\n",
56
- "# non normalized 1fwxpyfh ! requires changing normalize_text\n",
57
  "ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
58
- "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
59
- "normalize_text = True\n",
60
  "latest_only = True # log only latest or all versions\n",
61
  "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
62
  "add_clip_32 = False"
@@ -85,7 +62,7 @@
85
  "batch_size = 8\n",
86
  "num_images = 128\n",
87
  "top_k = 8\n",
88
- "text_normalizer = TextNormalizer() if normalize_text else None\n",
89
  "padding_item = 'NONE'\n",
90
  "seed = random.randint(0, 2**32-1)\n",
91
  "key = jax.random.PRNGKey(seed)\n",
@@ -100,11 +77,12 @@
100
  "outputs": [],
101
  "source": [
102
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
103
- "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
104
- "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
105
- "clip_params = replicate(clip.params)\n",
106
  "vqgan_params = replicate(vqgan.params)\n",
107
  "\n",
 
 
 
 
108
  "if add_clip_32:\n",
109
  " clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
110
  " processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
@@ -123,8 +101,8 @@
123
  " return vqgan.decode_code(indices, params=params)\n",
124
  "\n",
125
  "@partial(jax.pmap, axis_name=\"batch\")\n",
126
- "def p_clip(inputs, params):\n",
127
- " logits = clip(params=params, **inputs).logits_per_image\n",
128
  " return logits\n",
129
  "\n",
130
  "if add_clip_32:\n",
@@ -229,7 +207,7 @@
229
  "outputs": [],
230
  "source": [
231
  "run_id = run_ids[0]\n",
232
- "# TODO: turn everything into a class"
233
  ]
234
  },
235
  {
@@ -248,10 +226,8 @@
248
  "for artifact in artifact_versions:\n",
249
  " print(f'Processing artifact: {artifact.name}')\n",
250
  " version = int(artifact.version[1:])\n",
251
- " results = []\n",
252
- " if add_clip_32:\n",
253
- " results32 = []\n",
254
- " columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
255
  " \n",
256
  " if latest_only:\n",
257
  " assert last_inference_version is None or version > last_inference_version\n",
@@ -288,7 +264,7 @@
288
  "\n",
289
  " # process one batch of captions\n",
290
  " for batch in tqdm(samples):\n",
291
- " processed_prompts = [text_normalizer(x) for x in batch] if normalize_text else list(batch)\n",
292
  "\n",
293
  " # repeat the prompts to distribute over each device and tokenize\n",
294
  " processed_prompts = processed_prompts * jax.device_count()\n",
@@ -297,7 +273,7 @@
297
  "\n",
298
  " # generate images\n",
299
  " images = []\n",
300
- " pbar = tqdm(range(num_images // jax.device_count()), desc='Generating Images', leave=None)\n",
301
  " for i in pbar:\n",
302
  " key, subkey = jax.random.split(key)\n",
303
  " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
@@ -307,34 +283,13 @@
307
  " for img in decoded_images:\n",
308
  " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
309
  "\n",
310
- " # get clip scores\n",
311
- " pbar.set_description('Calculating CLIP scores')\n",
312
- " clip_inputs = processor(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
313
- " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
314
- " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
315
- " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
316
- " clip_inputs = shard(clip_inputs)\n",
317
- " logits = p_clip(clip_inputs, clip_params)\n",
318
- " logits = logits.reshape(-1, num_images)\n",
319
- " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
320
- " logits = jax.device_get(logits)\n",
321
- " # add to results table\n",
322
- " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
323
- " if sample == padding_item: continue\n",
324
- " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
325
- " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
326
- " top_scores = [scores[x] for x in idx]\n",
327
- " results.append([sample] + top_images + top_scores)\n",
328
- " \n",
329
- " # get clip 32 scores - TODO: this should be refactored as it is same code as above\n",
330
- " if add_clip_32:\n",
331
- " print('Calculating CLIP 32 scores')\n",
332
- " clip_inputs = processor32(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
333
  " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
334
  " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
335
  " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
336
  " clip_inputs = shard(clip_inputs)\n",
337
- " logits = p_clip32(clip_inputs, clip32_params)\n",
338
  " logits = logits.reshape(-1, num_images)\n",
339
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
340
  " logits = jax.device_get(logits)\n",
@@ -342,13 +297,24 @@
342
  " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
343
  " if sample == padding_item: continue\n",
344
  " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
345
- " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
346
- " top_scores = [scores[x] for x in idx]\n",
347
- " results32.append([sample] + top_images + top_scores)\n",
348
  " pbar.close()\n",
349
  "\n",
 
 
350
  " # log results\n",
351
- " table = wandb.Table(columns=columns, data=results)\n",
352
  " run.log({'Samples': table, 'version': version})\n",
353
  " wandb.finish()\n",
354
  " \n",
@@ -363,15 +329,10 @@
363
  {
364
  "cell_type": "code",
365
  "execution_count": null,
366
- "id": "4e4c7d0c-2848-4f88-b967-82fd571534f1",
367
  "metadata": {},
368
  "outputs": [],
369
- "source": [
370
- "# TODO: not implemented\n",
371
- "def log_runs(runs):\n",
372
- " for run in tqdm(runs):\n",
373
- " log_run(run)"
374
- ]
375
  }
376
  ],
377
  "metadata": {
 
24
  "from dalle_mini.text import TextNormalizer"
25
  ]
26
  },
 
27
  {
28
  "cell_type": "code",
29
  "execution_count": null,
 
31
  "metadata": {},
32
  "outputs": [],
33
  "source": [
34
+ "run_ids = ['63otg87g']\n",
 
 
 
35
  "ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
36
+ "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', 'e93a26e7707683d349bf5d5c41c5b0ef69b677a9'\n",
 
37
  "latest_only = True # log only latest or all versions\n",
38
  "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
39
  "add_clip_32 = False"
 
62
  "batch_size = 8\n",
63
  "num_images = 128\n",
64
  "top_k = 8\n",
65
+ "text_normalizer = TextNormalizer()\n",
66
  "padding_item = 'NONE'\n",
67
  "seed = random.randint(0, 2**32-1)\n",
68
  "key = jax.random.PRNGKey(seed)\n",
 
77
  "outputs": [],
78
  "source": [
79
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
 
 
 
80
  "vqgan_params = replicate(vqgan.params)\n",
81
  "\n",
82
+ "clip16 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
83
+ "processor16 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
84
+ "clip16_params = replicate(clip16.params)\n",
85
+ "\n",
86
  "if add_clip_32:\n",
87
  " clip32 = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
88
  " processor32 = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
 
101
  " return vqgan.decode_code(indices, params=params)\n",
102
  "\n",
103
  "@partial(jax.pmap, axis_name=\"batch\")\n",
104
+ "def p_clip16(inputs, params):\n",
105
+ " logits = clip16(params=params, **inputs).logits_per_image\n",
106
  " return logits\n",
107
  "\n",
108
  "if add_clip_32:\n",
 
207
  "outputs": [],
208
  "source": [
209
  "run_id = run_ids[0]\n",
210
+ "# TODO: loop over runs"
211
  ]
212
  },
213
  {
 
226
  "for artifact in artifact_versions:\n",
227
  " print(f'Processing artifact: {artifact.name}')\n",
228
  " version = int(artifact.version[1:])\n",
229
+ " results16, results32 = [], []\n",
230
+ " columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)]\n",
 
 
231
  " \n",
232
  " if latest_only:\n",
233
  " assert last_inference_version is None or version > last_inference_version\n",
 
264
  "\n",
265
  " # process one batch of captions\n",
266
  " for batch in tqdm(samples):\n",
267
+ " processed_prompts = [text_normalizer(x) for x in batch] if model.config.normalize_text else list(batch)\n",
268
  "\n",
269
  " # repeat the prompts to distribute over each device and tokenize\n",
270
  " processed_prompts = processed_prompts * jax.device_count()\n",
 
273
  "\n",
274
  " # generate images\n",
275
  " images = []\n",
276
+ " pbar = tqdm(range(num_images // jax.device_count()), desc='Generating Images', leave=True)\n",
277
  " for i in pbar:\n",
278
  " key, subkey = jax.random.split(key)\n",
279
  " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
 
283
  " for img in decoded_images:\n",
284
  " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))\n",
285
  "\n",
286
+ " def add_clip_results(results, processor, p_clip, clip_params): \n",
287
+ " clip_inputs = processor(text=batch, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
 
288
  " # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
289
  " images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
290
  " clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
291
  " clip_inputs = shard(clip_inputs)\n",
292
+ " logits = p_clip(clip_inputs, clip_params)\n",
293
  " logits = logits.reshape(-1, num_images)\n",
294
  " top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
295
  " logits = jax.device_get(logits)\n",
 
297
  " for i, (idx, scores, sample) in enumerate(zip(top_scores, logits, batch)):\n",
298
  " if sample == padding_item: continue\n",
299
  " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
300
+ " top_images = [wandb.Image(cur_images[x], caption=f'Score: {scores[x]:.2f}') for x in idx]\n",
301
+ " results.append([sample] + top_images)\n",
302
+ " \n",
303
+ " # get clip scores\n",
304
+ " pbar.set_description('Calculating CLIP 16 scores')\n",
305
+ " add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
306
+ " \n",
307
+ " # get clip 32 scores\n",
308
+ " if add_clip_32:\n",
309
+ " pbar.set_description('Calculating CLIP 32 scores')\n",
310
+ " add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
311
+ "\n",
312
  " pbar.close()\n",
313
  "\n",
314
+ " \n",
315
+ "\n",
316
  " # log results\n",
317
+ " table = wandb.Table(columns=columns, data=results16)\n",
318
  " run.log({'Samples': table, 'version': version})\n",
319
  " wandb.finish()\n",
320
  " \n",
 
329
  {
330
  "cell_type": "code",
331
  "execution_count": null,
332
+ "id": "415d3f54-7226-43de-9eea-4283a948dc93",
333
  "metadata": {},
334
  "outputs": [],
335
+ "source": []
 
 
 
 
 
336
  }
337
  ],
338
  "metadata": {
{dev → tools}/inference/samples.txt RENAMED
@@ -32,7 +32,9 @@ illustration of an astronaut in a space suit playing guitar
32
  a clown wearing a spacesuit floating in space
33
  a dog playing with a ball
34
  a cat sits on top of an alligator
 
35
  a rat holding a red lightsaber in a white background
 
36
  A unicorn is passing by a rainbow in a field of flowers
37
  an elephant made of carrots
38
  an elephant on a unicycle during a circus
@@ -40,6 +42,7 @@ photography of a penguin watching television
40
  a penguin is walking on the Moon, Earth is in the background
41
  a penguin standing on a tower of books holds onto a rope from a helicopter
42
  rat wearing a crown
 
43
  looking into the sky, 10 airplanes are seen overhead
44
  shelves filled with books and alchemy potion bottles
45
  this is a detailed high-resolution scan of a human brain
@@ -61,7 +64,6 @@ a cartoon of a superhero bear
61
  an illustration of a cute skeleton wearing a blue hoodie
62
  illustration of a baby shark swimming around corals
63
  an illustration of an avocado in a beanie riding a motorcycle
64
- Cartoon of a carrot with big eyes
65
  logo of a robot wearing glasses and reading a book
66
  illustration of a cactus lifting weigths
67
  logo of a cactus lifting weights
@@ -70,11 +72,12 @@ a skeleton with the shape of a spider
70
  a collection of glasses is sitting on a table
71
  a painting of a capybara sitting on a mountain during fall in surrealist style
72
  a pentagonal green clock
73
- a pixel art illustration of an eagle sitting in a field in the afternoon
74
  a small red block sitting on a large green block
75
  a storefront that has the word 'openai' written on it
76
  a tatoo of a black broccoli
77
  a variety of clocks is sitting on a table
 
 
78
  an emoji of a baby fox wearing a blue hat, green gloves, red shirt, and yellow pants
79
  an emoji of a baby penguin wearing a blue hat, blue gloves, red shirt, and green pants
80
  an extreme close-up view of a capybara sitting in a field
@@ -86,10 +89,11 @@ urinals are lined up in a jungle
86
  a muscular banana sitting upright on a bench smoking watching a banana on television, high definition photography
87
  a human face
88
  a person is holding a phone and a waterbottle, running a marathon
 
89
  Young woman riding her bike through the forest
90
  the best soccer team of the world
91
- the best basketball team of the world
92
  the best football team of the world
 
93
  happy, happiness
94
  sad, sadness
95
  the representation of infinity
@@ -105,3 +109,12 @@ an avocado armchair flying into space
105
  a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps
106
  an illustration of an avocado in a christmas sweater staring at its reflection in a mirror
107
  illustration of an avocado armchair getting married to a pineapple
 
32
  a clown wearing a spacesuit floating in space
33
  a dog playing with a ball
34
  a cat sits on top of an alligator
35
+ a very cute cat laying by a big bike
36
  a rat holding a red lightsaber in a white background
37
+ a very cute giraffe making a funny face
38
  A unicorn is passing by a rainbow in a field of flowers
39
  an elephant made of carrots
40
  an elephant on a unicycle during a circus
 
42
  a penguin is walking on the Moon, Earth is in the background
43
  a penguin standing on a tower of books holds onto a rope from a helicopter
44
  rat wearing a crown
45
+ Cartoon of a carrot with big eyes
46
  looking into the sky, 10 airplanes are seen overhead
47
  shelves filled with books and alchemy potion bottles
48
  this is a detailed high-resolution scan of a human brain
 
64
  an illustration of a cute skeleton wearing a blue hoodie
65
  illustration of a baby shark swimming around corals
66
  an illustration of an avocado in a beanie riding a motorcycle
 
67
  logo of a robot wearing glasses and reading a book
68
  illustration of a cactus lifting weigths
69
  logo of a cactus lifting weights
 
72
  a collection of glasses is sitting on a table
73
  a painting of a capybara sitting on a mountain during fall in surrealist style
74
  a pentagonal green clock
 
75
  a small red block sitting on a large green block
76
  a storefront that has the word 'openai' written on it
77
  a tatoo of a black broccoli
78
  a variety of clocks is sitting on a table
79
+ a table has a train model on it with other cars and things
80
+ a pixel art illustration of an eagle sitting in a field in the afternoon
81
  an emoji of a baby fox wearing a blue hat, green gloves, red shirt, and yellow pants
82
  an emoji of a baby penguin wearing a blue hat, blue gloves, red shirt, and green pants
83
  an extreme close-up view of a capybara sitting in a field
 
89
  a muscular banana sitting upright on a bench smoking watching a banana on television, high definition photography
90
  a human face
91
  a person is holding a phone and a waterbottle, running a marathon
92
+ a child eating a birthday cake near some balloons
93
  Young woman riding her bike through the forest
94
  the best soccer team of the world
 
95
  the best football team of the world
96
+ the best basketball team of the world
97
  happy, happiness
98
  sad, sadness
99
  the representation of infinity
 
109
  a cute avocado armchair singing karaoke on stage in front of a crowd of strawberry shaped lamps
110
  an illustration of an avocado in a christmas sweater staring at its reflection in a mirror
111
  illustration of an avocado armchair getting married to a pineapple
112
+ half human half cat
113
+ half human half dog
114
+ half human half pen
115
+ half human half garbage
116
+ half human half avocado
117
+ half human half Eiffel tower
118
+ a propaganda poster for transhumanism
119
+ a propaganda poster for building a space elevator
120
+ a beautiful epic fantasy painting of a space elevator
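Finally, a note on the renamed notebook (dev/inference/wandb-backend.ipynb → tools/inference/log_inference_samples.ipynb): the duplicated CLIP-16 / CLIP-32 scoring paths are folded into a single `add_clip_results` helper, and scores now travel as `wandb.Image` captions instead of separate table columns. The per-device scoring still runs through a `jax.pmap`-compiled function, roughly as in this sketch (names mirror the notebook cells; treat it as an illustration, not the notebook's exact code):

```python
from functools import partial

import jax
from flax.jax_utils import replicate
from transformers import CLIPProcessor, FlaxCLIPModel

# one CLIP ViT-B/16 scorer, with its parameters replicated across devices
clip16 = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch16")
processor16 = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
clip16_params = replicate(clip16.params)


@partial(jax.pmap, axis_name="batch")
def p_clip16(inputs, params):
    # each device scores its own shard of tokenized prompts and pixel values
    return clip16(params=params, **inputs).logits_per_image
```

Reshaping the gathered logits to `(-1, num_images)` and taking `argsort()[:, -top_k:][..., ::-1]` then yields the per-prompt top-k indices used to build the wandb table.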