File size: 3,262 Bytes
8e2b754 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import argparse
import json
import logging
import os
import time
import urllib.request
import urllib.error
import pandas as pd
from tqdm import tqdm
logger = logging.getLogger(__name__)
def _download_image(url: str, image_path: str, pause: float, retries: int) -> bool:
    """Download ``url`` to ``image_path``, retrying transient failures.

    Returns True when the file was downloaded, and False when the image should be
    skipped: a permanent HTTP client error (e.g. 404), a 5xx error that persists
    for ``retries`` attempts, or ``retries <= 0`` (no attempt is made).

    Raises:
        ValueError: if the server still answers HTTP 429 (Too Many Requests)
            on the last of ``retries`` attempts.
        urllib.error.URLError: if a non-HTTP network error (DNS failure,
            refused connection, truncated download, ...) persists for
            ``retries`` attempts.
    """
    for attempt in range(1, retries + 1):
        is_last = attempt == retries
        try:
            urllib.request.urlretrieve(url, image_path)
            return True
        except urllib.error.HTTPError as e:  # must precede URLError (subclass)
            if e.code != 429 and e.code < 500:
                # Client errors other than 429 (e.g. 404 Not Found) will not
                # go away by retrying, so skip the image immediately.
                logger.warning("Skipping %s: HTTP %s", url, e.code)
                return False
            if is_last:
                if e.code == 429:
                    raise ValueError(f"Rate limit achieved: {e}") from e
                logger.warning("Skipping %s after %d attempts: HTTP %s", url, retries, e.code)
                return False
        except urllib.error.URLError:
            # Non-HTTP failures are retried; if one persists, the network is
            # most likely unusable, so abort instead of skipping every image.
            if is_last:
                raise
        # Only pause when another attempt will follow.
        time.sleep(pause)
    return False


def _write_split(path: str, lines: list) -> None:
    """Write JSON-lines records to ``path``, newline-separated, as UTF-8.

    UTF-8 is explicit because the records are dumped with ``ensure_ascii=False``
    and captions may contain non-ASCII text; the platform default encoding may
    not be able to represent it.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))


def prepare_wit(
    tsv: str,
    language: str,
    output_dir: str,
    seed: int,
    train_proportion: float,
    valid_proportion: float,
    language_col: str = "language",
    caption_col: str = "caption_reference_description",
    url_col: str = "image_url",
    pause: float = 1.0,
    retries: int = 5,
):
    """Download WIT images for one language and write train/valid/test splits.

    Args:
        tsv: path to a tab-separated WIT file.
        language: value of ``language_col`` to keep (e.g. "es").
        output_dir: directory for downloaded images and the three split files
            ``train_dataset.json``, ``valid_dataset.json`` and
            ``test_dataset.json`` (created if missing).
        seed: random state used to shuffle the rows before splitting.
        train_proportion: fraction of successfully downloaded examples that go
            to the train split.
        valid_proportion: fraction that goes to the validation split; the
            remainder forms the test split.
        language_col, caption_col, url_col: column names in ``tsv``.
        pause: seconds to wait between download attempts of the same image.
        retries: maximum number of download attempts per image.

    Each line of a split file is ``{"image_path": ..., "captions": [caption]}``.
    Images that cannot be downloaded are skipped. The split files are written
    in a ``finally`` block, so examples downloaded before an error or an
    interruption are still saved; the exception then propagates. For example,
    a ``ValueError`` is raised when rate limiting persists.
    """
    os.makedirs(output_dir, exist_ok=True)
    df = pd.read_csv(tsv, sep="\t", engine="python")
    # Keep rows of the requested language that have both a caption and a URL.
    df = df[(df[language_col] == language) & df[caption_col].notna() & df[url_col].notna()]
    # Shuffle so that the positional split below is random but reproducible.
    df = df.sample(frac=1.0, random_state=seed)
    lines = []
    try:
        rows = df[[url_col, caption_col]].itertuples(index=False, name=None)
        for url, caption in tqdm(rows, total=len(df)):
            # Trim image file names so that they are no longer than 100 characters.
            # NOTE(review): distinct URLs sharing the same trimmed name overwrite each other.
            image_filename = url.split("/")[-1][-100:]
            image_path = os.path.join(output_dir, image_filename)
            if _download_image(url, image_path, pause, retries):
                lines.append(json.dumps({"image_path": image_path, "captions": [caption]}, ensure_ascii=False))
    finally:
        # Save the existing dataset, even upon failure.
        total_lines = len(lines)
        train_end = int(total_lines * train_proportion)
        valid_end = int(total_lines * (train_proportion + valid_proportion))
        _write_split(os.path.join(output_dir, "train_dataset.json"), lines[:train_end])
        _write_split(os.path.join(output_dir, "valid_dataset.json"), lines[train_end:valid_end])
        _write_split(os.path.join(output_dir, "test_dataset.json"), lines[valid_end:])
if __name__ == "__main__":
    import getpass

    # Make the per-image download warnings logged by prepare_wit visible.
    logging.basicConfig(level=logging.INFO)

    # $USER is not set in every environment (some containers, cron, Windows).
    # os.environ['USER'] raised KeyError while building the defaults below, so
    # the script crashed even for --help. Fall back to getpass.getuser(), which
    # also consults LOGNAME/USERNAME and the password database.
    user = os.environ.get("USER") or getpass.getuser()
    default_data_dir = f"/home/{user}/data/wit"

    parser = argparse.ArgumentParser(description="Download and prepare the WIT dataset")
    parser.add_argument("--tsv", type=str, default=f"{default_data_dir}/wit_v1.train.all-1percent_sample.tsv")
    parser.add_argument("--language", type=str, default="es")
    parser.add_argument("--output_dir", type=str, default=f"{default_data_dir}/prepared_dataset")
    parser.add_argument("--random_seed", type=int, default=0)
    parser.add_argument("--train_proportion", type=float, default=0.8)
    parser.add_argument("--valid_proportion", type=float, default=0.1)
    parser.add_argument("--pause", type=float, default=1.0, help="Seconds to wait between download attempts")
    parser.add_argument("--retries", type=int, default=5, help="Maximum download attempts per image")
    args = parser.parse_args()

    # Validate with parser.error instead of assert: asserts are stripped under
    # `python -O`, and parser.error prints the usage and exits with status 2.
    if args.train_proportion < 0 or args.valid_proportion < 0:
        parser.error("train_proportion and valid_proportion must be non-negative")
    if args.train_proportion + args.valid_proportion >= 1.0:
        parser.error("The sum of train_proportion and valid_proportion has to be < 1.0")
    if args.retries < 1:
        parser.error("retries must be at least 1")
    if args.pause < 0:
        parser.error("pause must be non-negative")

    prepare_wit(
        args.tsv,
        args.language,
        args.output_dir,
        args.random_seed,
        args.train_proportion,
        args.valid_proportion,
        pause=args.pause,
        retries=args.retries,
    )
|