style: reformat
tools/inference/inference_pipeline.ipynb
CHANGED
The diff for this file is too large to render. See raw diff.
tools/inference/log_inference_samples.ipynb
CHANGED
@@ -31,11 +31,14 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "run_ids = [
- "ENTITY, PROJECT =
- "VQGAN_REPO, VQGAN_COMMIT_ID =
- "
- "
+ "run_ids = [\"63otg87g\"]\n",
+ "ENTITY, PROJECT = \"dalle-mini\", \"dalle-mini\"  # used only for training run\n",
+ "VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
+ "    \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
+ "    \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\",\n",
+ ")\n",
+ "latest_only = True  # log only latest or all versions\n",
+ "suffix = \"\"  # mainly for duplicate inference runs with a deleted version\n",
  "add_clip_32 = False"
  ]
  },
@@ -63,8 +66,8 @@
  "num_images = 128\n",
  "top_k = 8\n",
  "text_normalizer = TextNormalizer()\n",
- "padding_item =
- "seed = random.randint(0, 2**32-1)\n",
+ "padding_item = \"NONE\"\n",
+ "seed = random.randint(0, 2 ** 32 - 1)\n",
  "key = jax.random.PRNGKey(seed)\n",
  "api = wandb.Api()"
  ]
@@ -100,12 +103,15 @@
  "def p_decode(indices, params):\n",
  "    return vqgan.decode_code(indices, params=params)\n",
  "\n",
+ "\n",
  "@partial(jax.pmap, axis_name=\"batch\")\n",
  "def p_clip16(inputs, params):\n",
  "    logits = clip16(params=params, **inputs).logits_per_image\n",
  "    return logits\n",
  "\n",
+ "\n",
  "if add_clip_32:\n",
+ "\n",
  "    @partial(jax.pmap, axis_name=\"batch\")\n",
  "    def p_clip32(inputs, params):\n",
  "        logits = clip32(params=params, **inputs).logits_per_image\n",
@@ -119,13 +125,13 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "with open(
+ "with open(\"samples.txt\", encoding=\"utf8\") as f:\n",
  "    samples = [l.strip() for l in f.readlines()]\n",
  "    # make list multiple of batch_size by adding elements\n",
  "    samples_to_add = [padding_item] * (-len(samples) % batch_size)\n",
  "    samples.extend(samples_to_add)\n",
  "    # reshape\n",
- "    samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
+ "    samples = [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]"
  ]
  },
  {
@@ -138,9 +144,17 @@
  "def get_artifact_versions(run_id, latest_only=False):\n",
  "    try:\n",
  "        if latest_only:\n",
- "            return [
+ "            return [\n",
+ "                api.artifact(\n",
+ "                    type=\"bart_model\", name=f\"{ENTITY}/{PROJECT}/model-{run_id}:latest\"\n",
+ "                )\n",
+ "            ]\n",
  "        else:\n",
- "            return api.artifact_versions(
+ "            return api.artifact_versions(\n",
+ "                type_name=\"bart_model\",\n",
+ "                name=f\"{ENTITY}/{PROJECT}/model-{run_id}\",\n",
+ "                per_page=10000,\n",
+ "            )\n",
  "    except:\n",
  "        return []"
  ]
@@ -153,7 +167,7 @@
  "outputs": [],
  "source": [
  "def get_training_config(run_id):\n",
- "    training_run = api.run(f
+ "    training_run = api.run(f\"{ENTITY}/{PROJECT}/{run_id}\")\n",
  "    config = training_run.config\n",
  "    return config"
  ]
@@ -168,8 +182,8 @@
  "# retrieve inference run details\n",
  "def get_last_inference_version(run_id):\n",
  "    try:\n",
- "        inference_run = api.run(f
- "        return inference_run.summary.get(
+ "        inference_run = api.run(f\"dalle-mini/dalle-mini/{run_id}-clip16{suffix}\")\n",
+ "        return inference_run.summary.get(\"version\", None)\n",
  "    except:\n",
  "        return None"
  ]
@@ -183,7 +197,6 @@
  "source": [
  "# compile functions - needed only once per run\n",
  "def pmap_model_function(model):\n",
- "    \n",
  "    @partial(jax.pmap, axis_name=\"batch\")\n",
  "    def _generate(tokenized_prompt, key, params):\n",
  "        return model.generate(\n",
@@ -195,7 +208,7 @@
  "            top_k=gen_top_k,\n",
  "            top_p=gen_top_p\n",
  "        )\n",
- "
+ "\n",
  "    return _generate"
  ]
  },
@@ -222,13 +235,21 @@
  "training_config = get_training_config(run_id)\n",
  "run = None\n",
  "p_generate = None\n",
- "model_files = [
+ "model_files = [\n",
+ "    \"config.json\",\n",
+ "    \"flax_model.msgpack\",\n",
+ "    \"merges.txt\",\n",
+ "    \"special_tokens_map.json\",\n",
+ "    \"tokenizer.json\",\n",
+ "    \"tokenizer_config.json\",\n",
+ "    \"vocab.json\",\n",
+ "]\n",
  "for artifact in artifact_versions:\n",
- "    print(f
+ "    print(f\"Processing artifact: {artifact.name}\")\n",
  "    version = int(artifact.version[1:])\n",
  "    results16, results32 = [], []\n",
- "    columns = [
- "
+ "    columns = [\"Caption\"] + [f\"Image {i+1}\" for i in range(top_k)]\n",
+ "\n",
  "    if latest_only:\n",
  "        assert last_inference_version is None or version > last_inference_version\n",
  "    else:\n",
@@ -236,14 +257,23 @@
  "        # we should start from v0\n",
  "        assert version == 0\n",
  "    elif version <= last_inference_version:\n",
- "        print(
+ "        print(\n",
+ "            f\"v{version} has already been logged (versions logged up to v{last_inference_version}\"\n",
+ "        )\n",
  "    else:\n",
  "        # check we are logging the correct version\n",
  "        assert version == last_inference_version + 1\n",
  "\n",
  "    # start/resume corresponding run\n",
  "    if run is None:\n",
- "        run = wandb.init(
+ "        run = wandb.init(\n",
+ "            job_type=\"inference\",\n",
+ "            entity=\"dalle-mini\",\n",
+ "            project=\"dalle-mini\",\n",
+ "            config=training_config,\n",
+ "            id=f\"{run_id}-clip16{suffix}\",\n",
+ "            resume=\"allow\",\n",
+ "        )\n",
  "\n",
  "    # work in temporary directory\n",
  "    with tempfile.TemporaryDirectory() as tmp:\n",
@@ -264,64 +294,109 @@
  "\n",
  "        # process one batch of captions\n",
  "        for batch in tqdm(samples):\n",
- "            processed_prompts =
+ "            processed_prompts = (\n",
+ "                [text_normalizer(x) for x in batch]\n",
+ "                if model.config.normalize_text\n",
+ "                else list(batch)\n",
+ "            )\n",
  "\n",
  "            # repeat the prompts to distribute over each device and tokenize\n",
  "            processed_prompts = processed_prompts * jax.device_count()\n",
- "            tokenized_prompt = tokenizer(
+ "            tokenized_prompt = tokenizer(\n",
+ "                processed_prompts,\n",
+ "                return_tensors=\"jax\",\n",
+ "                padding=\"max_length\",\n",
+ "                truncation=True,\n",
+ "                max_length=128,\n",
+ "            ).data\n",
  "            tokenized_prompt = shard(tokenized_prompt)\n",
  "\n",
  "            # generate images\n",
  "            images = []\n",
- "            pbar = tqdm(
+ "            pbar = tqdm(\n",
+ "                range(num_images // jax.device_count()),\n",
+ "                desc=\"Generating Images\",\n",
+ "                leave=True,\n",
+ "            )\n",
  "            for i in pbar:\n",
  "                key, subkey = jax.random.split(key)\n",
- "                encoded_images = p_generate(
+ "                encoded_images = p_generate(\n",
+ "                    tokenized_prompt, shard_prng_key(subkey), model_params\n",
+ "                )\n",
  "                encoded_images = encoded_images.sequences[..., 1:]\n",
  "                decoded_images = p_decode(encoded_images, vqgan_params)\n",
- "                decoded_images = decoded_images.clip(0
+ "                decoded_images = decoded_images.clip(0.0, 1.0).reshape(\n",
+ "                    (-1, 256, 256, 3)\n",
+ "                )\n",
  "                for img in decoded_images:\n",
- "                    images.append(
+ "                    images.append(\n",
+ "                        Image.fromarray(np.asarray(img * 255, dtype=np.uint8))\n",
+ "                    )\n",
  "\n",
- "            def add_clip_results(results, processor, p_clip, clip_params)
- "                clip_inputs = processor(
+ "            def add_clip_results(results, processor, p_clip, clip_params):\n",
+ "                clip_inputs = processor(\n",
+ "                    text=batch,\n",
+ "                    images=images,\n",
+ "                    return_tensors=\"np\",\n",
+ "                    padding=\"max_length\",\n",
+ "                    max_length=77,\n",
+ "                    truncation=True,\n",
+ "                ).data\n",
  "                # each shard will have one prompt, images need to be reorganized to be associated to the correct shard\n",
- "                images_per_prompt_indices = np.asarray(
- "
+ "                images_per_prompt_indices = np.asarray(\n",
+ "                    range(0, len(images), batch_size)\n",
+ "                )\n",
+ "                clip_inputs[\"pixel_values\"] = jnp.concatenate(\n",
+ "                    list(\n",
+ "                        clip_inputs[\"pixel_values\"][images_per_prompt_indices + i]\n",
+ "                        for i in range(batch_size)\n",
+ "                    )\n",
+ "                )\n",
  "                clip_inputs = shard(clip_inputs)\n",
  "                logits = p_clip(clip_inputs, clip_params)\n",
  "                logits = logits.reshape(-1, num_images)\n",
  "                top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
  "                logits = jax.device_get(logits)\n",
  "                # add to results table\n",
- "                for i, (idx, scores, sample) in enumerate(
- "
+ "                for i, (idx, scores, sample) in enumerate(\n",
+ "                    zip(top_scores, logits, batch)\n",
+ "                ):\n",
+ "                    if sample == padding_item:\n",
+ "                        continue\n",
  "                    cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
- "                    top_images = [
+ "                    top_images = [\n",
+ "                        wandb.Image(cur_images[x], caption=f\"Score: {scores[x]:.2f}\")\n",
+ "                        for x in idx\n",
+ "                    ]\n",
  "                    results.append([sample] + top_images)\n",
- "
+ "\n",
  "            # get clip scores\n",
- "            pbar.set_description(
+ "            pbar.set_description(\"Calculating CLIP 16 scores\")\n",
  "            add_clip_results(results16, processor16, p_clip16, clip16_params)\n",
- "
+ "\n",
  "            # get clip 32 scores\n",
  "            if add_clip_32:\n",
- "                pbar.set_description(
+ "                pbar.set_description(\"Calculating CLIP 32 scores\")\n",
  "                add_clip_results(results32, processor32, p_clip32, clip32_params)\n",
  "\n",
  "            pbar.close()\n",
  "\n",
- "    \n",
- "\n",
  "    # log results\n",
  "    table = wandb.Table(columns=columns, data=results16)\n",
- "    run.log({
+ "    run.log({\"Samples\": table, \"version\": version})\n",
  "    wandb.finish()\n",
- "
- "    if add_clip_32
- "    run = wandb.init(
+ "\n",
+ "    if add_clip_32:\n",
+ "        run = wandb.init(\n",
+ "            job_type=\"inference\",\n",
+ "            entity=\"dalle-mini\",\n",
+ "            project=\"dalle-mini\",\n",
+ "            config=training_config,\n",
+ "            id=f\"{run_id}-clip32{suffix}\",\n",
+ "            resume=\"allow\",\n",
+ "        )\n",
  "    table = wandb.Table(columns=columns, data=results32)\n",
- "    run.log({
+ "    run.log({\"Samples\": table, \"version\": version})\n",
  "    wandb.finish()\n",
  "    run = None  # ensure we don't log on this run"
  ]
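Note: the `samples.txt` cell reformatted above relies on a pad-and-chunk idiom: `-len(samples) % batch_size` is exactly the number of padding items needed to reach the next multiple of `batch_size`. A minimal standalone sketch with illustrative values (not the notebook's code):

```python
# Pad-and-chunk idiom from the samples cell, with illustrative values.
padding_item = "NONE"
batch_size = 4
samples = ["a red car", "a blue bird", "a green hat", "a tall tree", "a small dog"]

# -len(samples) % batch_size is the count of items missing to reach the next
# multiple of batch_size (0 when the length is already a multiple).
samples_to_add = [padding_item] * (-len(samples) % batch_size)
samples.extend(samples_to_add)

# Split the flat list into batches of exactly batch_size.
samples = [samples[i : i + batch_size] for i in range(0, len(samples), batch_size)]
assert all(len(batch) == batch_size for batch in samples)
# -> [['a red car', 'a blue bird', 'a green hat', 'a tall tree'],
#     ['a small dog', 'NONE', 'NONE', 'NONE']]
```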
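Note: the regrouping inside `add_clip_results` depends on generation order. Images are appended in prompt order (p0, p1, ..., p0, p1, ...), so every image of prompt `i` sits at a flat index congruent to `i` modulo `batch_size`, and `images_per_prompt_indices + i` collects all of them. A small sketch of that indexing with made-up sizes (illustrative, not the notebook's code):

```python
import numpy as np

batch_size = 2  # prompts per batch (illustrative)
rounds = 3      # generation rounds (illustrative)

# Fake "images" appended in prompt order each round: p0, p1, p0, p1, ...
images = [f"p{i % batch_size}_img{i // batch_size}" for i in range(batch_size * rounds)]

# Index of prompt 0's images; adding i shifts to prompt i's images.
images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))
for i in range(batch_size):
    per_prompt = [images[x] for x in images_per_prompt_indices + i]
    print(i, per_prompt)  # every image belonging to prompt i
```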