Spaces:
Runtime error
Runtime error
File size: 4,827 Bytes
5282eae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import argparse
import base64
import json
import os
import tarfile
import uuid
import zipfile
import time
import braceexpand
import webdataset as wds
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map
# Command-line interface. The three path arguments are required: without them
# the script previously failed later with an opaque TypeError on None.
arg_parser = argparse.ArgumentParser(
    description=(
        "Attach base64-encoded images from image shards to the matching JSONL "
        "documents and re-shard the result as WebDataset tar files."
    ),
)
arg_parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Root output directory; each run writes into a <timestamp>/<task index>/ subfolder.",
)
arg_parser.add_argument(
    "--image_shards",
    type=str,
    required=True,
    help="Pass in a list of shards in the format path_to_shard/shard_{0..23098}_images_v2.tar",
)
arg_parser.add_argument(
    "--doc_shards",
    type=str,
    required=True,
    help="Pass in a list of shards in the format path_to_shard/docs_shard_{0..23098}_v2.jsonl.zip",
)
arg_parser.add_argument(
    "--thread",
    type=int,
    default=128,
    help="Number of groups the shard pairs are split into round-robin.",
)
args = arg_parser.parse_args()
def get_txt_to_filename_dict(image_shards, disable_tqdm=False):
    """Build a lookup from each sample's ``txt`` stem to its ``json["key"]``.

    Streams the ``(txt, json)`` pairs of *image_shards* through WebDataset.
    The dictionary key is the ``txt`` field up to its first ``.``. The value is
    the ``"key"`` entry of the sample's JSON metadata. When two samples share a
    stem, the later one wins.

    Args:
        image_shards: shard path or URL spec accepted by ``wds.WebDataset``.
        disable_tqdm: suppress the progress bar when True.

    Returns:
        dict mapping stem -> key.
    """
    # NOTE(review): decode("pil") presumably also decodes image fields, which
    # to_tuple("txt", "json") then discards; confirm whether it can be dropped.
    pairs = wds.WebDataset(image_shards).decode("pil").to_tuple("txt", "json")
    stem_to_key = {}
    for txt_field, meta in tqdm(pairs, disable=disable_tqdm):
        stem_to_key[txt_field.split(".")[0]] = meta["key"]
    return stem_to_key
def _encode_image(image_tar, txt_to_filename_dict, image_name):
    """Return *image_name*'s JPEG from *image_tar* as base64 text, or None if unavailable.

    The stem of *image_name* (text before the first ``.``) is looked up in
    *txt_to_filename_dict*. The result names the tar member ``<key>.jpg``.
    """
    key = txt_to_filename_dict.get(image_name.split(".")[0])
    if key is None:
        return None
    try:
        member = image_tar.extractfile(key + ".jpg")
    except KeyError:
        # The index names a member this tar does not contain.
        return None
    if member is None:
        # extractfile() returns None for members that are not regular files.
        return None
    with member:
        return base64.b64encode(member.read()).decode("utf-8")


def single_thread(args):
    """Rewrite one task's doc shards with their images embedded as base64.

    ``args`` is a task dict with keys ``i`` (task index), ``output_dir``, and
    the parallel lists ``doc_shards`` / ``image_shards``.

    For every shard pair:
    - each JSON line of the first file in the doc zip gets an ``image_base64``
      field on every ``image_info`` entry;
    - the field is ``"null"`` when the image cannot be found in the paired tar;
    - the updated record is written to ``output_dir`` through a
      ``wds.ShardWriter`` with at most 1000 samples per tar.

    Only task 0 prints progress output.
    """
    i = args["i"]
    output_dir = args["output_dir"]
    doc_shards = args["doc_shards"]
    image_shards = args["image_shards"]
    show_progress = i == 0
    if show_progress:
        tqdm.write(f"output_dir: {output_dir}")
        tqdm.write(f"doc_shards: {doc_shards[:5]}")
        tqdm.write(f"image_shards: {image_shards[:5]}")
    with wds.ShardWriter(os.path.join(output_dir, "%09d.tar"), maxcount=1000) as sink:
        sink.verbose = False
        for doc_shard, image_shard in tqdm(
            zip(doc_shards, image_shards),
            disable=not show_progress,
            total=len(doc_shards),
        ):
            # Index the image shard (stem -> member key), then read the image
            # bytes directly from the tar. Both were previously commented out,
            # so every lookup raised NameError and was swallowed.
            txt_to_filename_dict = get_txt_to_filename_dict(
                image_shard, disable_tqdm=not show_progress
            )
            total_num = 0
            exist_num = 0
            with tarfile.open(image_shard) as image_tar, zipfile.ZipFile(doc_shard, "r") as zip_file:
                # Assumes the JSON file is the first file in the archive
                json_filename = zip_file.namelist()[0]
                with zip_file.open(json_filename, "r") as json_file:
                    for line in json_file:
                        if not line.strip():
                            continue  # tolerate blank lines such as a trailing newline
                        sample_data = json.loads(line)
                        for image in sample_data["image_info"]:
                            total_num += 1
                            image_name = image["image_name"]
                            image_base64 = _encode_image(image_tar, txt_to_filename_dict, image_name)
                            if image_base64 is None:
                                tqdm.write(f"{image_name.split('.')[0]}")
                                image_base64 = "null"
                            else:
                                exist_num += 1
                            image["image_base64"] = image_base64
                        sink.write({"__key__": uuid.uuid4().hex, "json": sample_data})
            # Guarded: a shard whose samples list no images must not divide by zero.
            if show_progress and total_num:
                tqdm.write(
                    f"{doc_shard}: found {exist_num}/{total_num} images "
                    f"({exist_num / total_num:.2f})"
                )
def main():
    """Split the shard pairs into tasks and process them in parallel.

    Shard pair ``k`` is assigned round-robin to task ``k % num_tasks``. Each
    task writes into ``<output_dir>/<unix timestamp>/<task index>/``.

    Raises:
        ValueError: if ``--thread`` is not positive, or the expanded doc and
            image shard lists have different lengths.
    """
    if args.thread < 1:
        raise ValueError(f"--thread must be a positive integer, got {args.thread}")
    doc_shards = list(braceexpand.braceexpand(args.doc_shards))
    image_shards = list(braceexpand.braceexpand(args.image_shards))
    # Raise instead of assert: assertions are stripped under `python -O`.
    if len(doc_shards) != len(image_shards):
        raise ValueError(
            "Each doc shards must have a corresponding image shard "
            f"({len(doc_shards)} doc shards vs {len(image_shards)} image shards)"
        )
    # Never create more tasks than shard pairs; surplus tasks would only
    # produce empty directories and tar files.
    num_tasks = min(args.thread, len(doc_shards))
    if num_tasks == 0:
        tqdm.write("No shards to process.")
        return
    run_dir = os.path.join(args.output_dir, str(int(time.time())))
    tasks = []
    for i in range(num_tasks):
        thread_dir = os.path.join(run_dir, str(i))
        os.makedirs(thread_dir, exist_ok=True)  # also creates output_dir and run_dir
        tasks.append({
            "i": i,
            "output_dir": thread_dir,
            "doc_shards": doc_shards[i::num_tasks],
            "image_shards": image_shards[i::num_tasks],
        })
    # Run every task. Running only tasks[0] would silently drop all shards
    # assigned to the other tasks.
    process_map(single_thread, tasks, max_workers=num_tasks, disable=True)


if __name__ == "__main__":
    main()
|