|
import argparse |
|
import json |
|
import logging |
|
import os |
|
import time |
|
import urllib.request |
|
import urllib.error |
|
|
|
import pandas as pd |
|
from tqdm import tqdm |
|
|
|
|
|
# Module-scoped logger, keyed by this module's import name.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
def prepare_wit(
    tsv: str,
    language: str,
    output_dir: str,
    seed: int,
    train_proportion: float,
    valid_proportion: float,
    language_col: str = "language",
    caption_col: str = "caption_reference_description",
    url_col: str = "image_url",
    pause: float = 1.0,
    retries: int = 5,
):
    """Download WIT images for one language and write train/valid/test splits.

    Rows of the TSV whose ``language_col`` equals ``language`` and whose
    ``caption_col`` is not null are shuffled with ``seed``. Each row's image is
    downloaded from ``url_col`` into ``output_dir``. Every successful download
    yields one JSON line ``{"image_path": ..., "captions": [caption]}``.

    The lines are split in order into three files in ``output_dir``:

    - ``train_dataset.json``: the first ``train_proportion`` of the lines.
    - ``valid_dataset.json``: the next ``valid_proportion`` of the lines.
    - ``test_dataset.json``: the remainder.

    The split files are written in a ``finally`` clause. If downloading stops
    early (an exception or KeyboardInterrupt), the lines collected so far are
    still saved.

    Args:
        tsv: Path to the tab-separated WIT file.
        language: Value of ``language_col`` to keep.
        output_dir: Directory for images and split files; created if missing.
        seed: ``random_state`` for the shuffle.
        train_proportion: Fraction of collected lines for the train split.
        valid_proportion: Fraction of collected lines for the validation split.
        language_col: Column holding the language code.
        caption_col: Column holding the caption text.
        url_col: Column holding the image URL.
        pause: Seconds to wait before retrying after an HTTP error.
        retries: Maximum number of download attempts per image.

    Raises:
        ValueError: If an image still fails with an HTTP error after
            ``retries`` attempts. The split files are written before it
            propagates.
    """
    os.makedirs(output_dir, exist_ok=True)

    df = pd.read_csv(tsv, sep="\t", engine="python")
    # Honour the configurable column names when filtering.
    df = df[(df[language_col] == language) & df[caption_col].notna()]
    df = df.sample(frac=1.0, random_state=seed)

    lines = []
    try:
        # A single progress bar, advanced manually once per processed row.
        with tqdm(total=len(df)) as pbar:
            for url, caption in zip(df[url_col], df[caption_col]):
                # Keep at most the last 100 chars of the URL's basename as the file name.
                image_filename = url.split("/")[-1][-100:]
                image_path = os.path.join(output_dir, image_filename)
                if _download_with_retries(url, image_path, retries, pause):
                    lines.append(
                        json.dumps(
                            {"image_path": image_path, "captions": [caption]},
                            ensure_ascii=False,
                        )
                    )
                pbar.update(1)
    finally:
        _write_splits(lines, output_dir, train_proportion, valid_proportion)


def _download_with_retries(url: str, image_path: str, retries: int, pause: float) -> bool:
    """Download ``url`` to ``image_path``, retrying on HTTP errors.

    Returns True once a download succeeds. Returns False only when
    ``retries`` < 1, in which case no attempt is made.

    Raises:
        ValueError: When every attempt fails. It is chained from the last
            HTTPError.
    """
    for attempt in range(retries):
        try:
            urllib.request.urlretrieve(url, image_path)
            return True
        except urllib.error.HTTPError as e:
            # Final attempt failed: give up loudly instead of silently dropping the image.
            if attempt == retries - 1:
                raise ValueError(f"Rate limit achieved: {e}") from e
            time.sleep(pause)
    return False


def _write_splits(lines: list, output_dir: str, train_proportion: float, valid_proportion: float) -> None:
    """Split ``lines`` in order into train/valid/test files in ``output_dir``.

    Each file holds its lines joined by newlines.
    """
    total_lines = len(lines)
    train_end = int(total_lines * train_proportion)
    valid_end = int(total_lines * (train_proportion + valid_proportion))
    splits = {
        "train": lines[:train_end],
        "valid": lines[train_end:valid_end],
        "test": lines[valid_end:],
    }
    for split_name, split_lines in splits.items():
        # Lines are serialized with ensure_ascii=False, so captions may contain
        # non-ASCII text; write UTF-8 explicitly rather than the locale default.
        split_path = os.path.join(output_dir, f"{split_name}_dataset.json")
        with open(split_path, "w", encoding="utf-8") as f:
            f.write("\n".join(split_lines))
|
|
|
if __name__ == "__main__":
    # Resolve the home directory via expanduser so the defaults do not require
    # the USER environment variable. argparse defaults are computed even when
    # the corresponding flags are passed explicitly.
    default_root = os.path.join(os.path.expanduser("~"), "data", "wit")

    parser = argparse.ArgumentParser(description="Download and prepare the WIT dataset")
    parser.add_argument("--tsv", type=str, default=os.path.join(default_root, "wit_v1.train.all-1percent_sample.tsv"))
    parser.add_argument("--language", type=str, default="es")
    parser.add_argument("--output_dir", type=str, default=os.path.join(default_root, "prepared_dataset"))
    parser.add_argument("--random_seed", type=int, default=0)
    parser.add_argument("--train_proportion", type=float, default=0.8)
    parser.add_argument("--valid_proportion", type=float, default=0.1)
    parser.add_argument("--pause", type=float, default=1.0, help="Seconds to wait between retries after an HTTP error")
    parser.add_argument("--retries", type=int, default=5, help="Maximum download attempts per image")

    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are stripped under `python -O`.
    if args.train_proportion < 0 or args.valid_proportion < 0:
        parser.error("train_proportion and valid_proportion must be >= 0")
    if args.train_proportion + args.valid_proportion >= 1.0:
        parser.error("The sum of train_proportion and valid_proportion has to be < 1.0")

    prepare_wit(
        args.tsv,
        args.language,
        args.output_dir,
        args.random_seed,
        args.train_proportion,
        args.valid_proportion,
        pause=args.pause,
        retries=args.retries,
    )
|
|