"""Download the images of the WIT dataset for one language and write train/valid/test splits.

Each split file holds one JSON object per line of the form
``{"image_path": <local path>, "captions": [<caption>]}``.
"""
import argparse
import json
import logging
import os
import time
import urllib.error
import urllib.request
from typing import List

import pandas as pd

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# 4xx responses that are worth retrying (request timeout, rate limiting). Any other
# 4xx (e.g. 404 Not Found) is treated as permanent; 5xx and connection errors are retried.
_RETRYABLE_CLIENT_ERROR_CODES = frozenset({408, 429})
_RATE_LIMIT_CODE = 429


def split_and_save_datasets(lines: List[str], output_dir: str, train_proportion: float,
                            valid_proportion: float):
    """Split ``lines`` in order into train/valid/test files inside ``output_dir``.

    The first ``int(n * train_proportion)`` lines go to ``train_dataset.json``, the lines up
    to ``int(n * (train_proportion + valid_proportion))`` go to ``valid_dataset.json`` and
    the rest go to ``test_dataset.json``. Lines are newline-separated with no trailing
    newline. Existing files are overwritten.
    """
    total = len(lines)
    train_end = int(total * train_proportion)
    valid_end = int(total * (train_proportion + valid_proportion))
    splits = (
        ("train_dataset.json", lines[:train_end]),
        ("valid_dataset.json", lines[train_end:valid_end]),
        ("test_dataset.json", lines[valid_end:]),
    )
    for filename, split_lines in splits:
        # Explicit UTF-8: lines are serialized with ensure_ascii=False and may contain
        # characters the platform default encoding cannot represent.
        with open(os.path.join(output_dir, filename), "w", encoding="utf-8") as f:
            f.write("\n".join(split_lines))


def _download_with_retries(url: str, path: str, retries: int, pause: float) -> bool:
    """Download ``url`` to ``path``, retrying transient failures.

    Makes at most ``max(retries, 1)`` attempts, sleeping ``pause`` seconds between them.

    Returns:
        True if the file was downloaded, False if the image should be skipped (a
        permanent HTTP client error, or every attempt failed).

    Raises:
        ValueError: if the final attempt was rejected with HTTP 429 (rate limited).
    """
    last_error = None
    for attempt in range(max(retries, 1)):
        if attempt:
            time.sleep(pause)  # back off before retrying
        try:
            urllib.request.urlretrieve(url, path)
            return True
        except urllib.error.HTTPError as e:
            last_error = e
            if 400 <= e.code < 500 and e.code not in _RETRYABLE_CLIENT_ERROR_CODES:
                break  # permanent client error (e.g. 404): retrying cannot help
        except (urllib.error.URLError, ConnectionError) as e:
            last_error = e
    if isinstance(last_error, urllib.error.HTTPError) and last_error.code == _RATE_LIMIT_CODE:
        raise ValueError(f"Rate limit achieved while downloading {url}") from last_error
    logger.warning("Skipping %s: %s", url, last_error)
    return False


def prepare_wit(tsv: str, language: str, output_dir: str, seed: int, train_proportion: float,
                valid_proportion: float, backup_period: int, language_col: str = "language",
                caption_col: str = "caption_reference_description", url_col: str = "image_url",
                pause: float = 0.1, retries: int = 5):
    """Download the images of ``language`` rows from a WIT TSV and write dataset splits.

    Rows whose ``language_col`` equals ``language`` and whose ``caption_col`` is not null
    are shuffled with ``seed``. Each image is downloaded into ``output_dir``. Every
    ``backup_period`` successful downloads the splits are rewritten, and they are always
    written once more on exit, even if an exception aborts the run.

    Args:
        tsv: Path to the WIT tab-separated file.
        language: Language code to keep (compared against ``language_col``).
        output_dir: Directory for downloaded images and split files (created if missing).
        seed: Random state used to shuffle the rows.
        train_proportion: Fraction of lines written to the train split.
        valid_proportion: Fraction of lines written to the validation split.
        backup_period: Number of successful downloads between intermediate saves (> 0).
        language_col, caption_col, url_col: Column names in the TSV.
        pause: Seconds to sleep between retries of a failed download.
        retries: Maximum number of download attempts per image.

    Raises:
        ValueError: if a download is still rate limited (HTTP 429) after all retries.
    """
    # Imported lazily so the helpers above can be imported without tqdm installed.
    from tqdm import tqdm

    os.makedirs(output_dir, exist_ok=True)
    logger.info("Loading dataset")
    df = pd.read_csv(tsv, sep="\t", engine="python")
    df = df[(df[language_col] == language) & df[caption_col].notna()]
    # Shuffle
    df = df.sample(frac=1.0, random_state=seed)

    logger.info("Download started")
    lines = []
    try:
        for url, caption in tqdm(zip(df[url_col], df[caption_col]), total=len(df)):
            # Trim image file names so that they are no longer than 100 characters
            image_filename = url.split("/")[-1][-100:]
            image_path = os.path.join(output_dir, image_filename)
            if not _download_with_retries(url, image_path, retries, pause):
                continue
            lines.append(json.dumps({"image_path": image_path, "captions": [caption]},
                                    ensure_ascii=False))
            # Back up only after a success, so failed downloads don't rewrite the files.
            if len(lines) % backup_period == 0:
                logger.info("Saving dataset backup: Number of lines %d", len(lines))
                split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)
    finally:
        # Save existing dataset, even upon failure
        split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)


def main():
    """Parse command-line arguments and run :func:`prepare_wit`."""
    wit_dir = os.path.join(os.path.expanduser("~"), "data", "wit")
    parser = argparse.ArgumentParser(description="Download and prepare the WIT dataset")
    parser.add_argument("--tsv", type=str,
                        default=os.path.join(wit_dir, "wit_v1.train.all-1percent_sample.tsv"))
    parser.add_argument("--language", type=str, default="es")
    parser.add_argument("--output_dir", type=str,
                        default=os.path.join(wit_dir, "prepared_dataset"))
    parser.add_argument("--random_seed", type=int, default=0)
    parser.add_argument("--train_proportion", type=float, default=0.8)
    parser.add_argument("--valid_proportion", type=float, default=0.1)
    parser.add_argument("--backup_period", type=int, default=1000)
    args = parser.parse_args()

    # parser.error instead of assert: asserts are stripped when running with -O.
    if not args.train_proportion + args.valid_proportion < 1.0:
        parser.error("The sum of train_proportion and valid_proportion has to be < 1.0")
    if args.backup_period <= 0:
        parser.error("backup_period has to be > 0")

    prepare_wit(args.tsv, args.language, args.output_dir, args.random_seed,
                args.train_proportion, args.valid_proportion, args.backup_period)


if __name__ == "__main__":
    main()