|
import argparse |
|
import json |
|
import logging |
|
import os |
|
import time |
|
import urllib.error |
|
import urllib.request |
|
from typing import List |
|
|
|
import pandas as pd |
|
from tqdm import tqdm |
|
|
|
# Configure the root logger once for the whole script (INFO and above,
# timestamped "date time - LEVEL - logger name - message" lines), then grab a
# module-scoped logger for this file's messages.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(name=__name__)
|
|
|
|
|
def split_and_save_datasets(
    lines: List[str], output_dir: str, train_proportion: float, valid_proportion: float
):
    """Split serialized records into train/valid/test files inside ``output_dir``.

    The first ``train_proportion`` of ``lines`` is written to
    ``train_dataset.json``, the following ``valid_proportion`` to
    ``valid_dataset.json`` and everything that remains to
    ``test_dataset.json``. Slices are taken in order (no shuffling happens
    here). Each file receives one element of ``lines`` per text line and is
    overwritten on every call.

    Args:
        lines: Serialized JSON records, one per element.
        output_dir: Existing directory the three files are written into.
        train_proportion: Fraction of ``lines`` assigned to the train split.
        valid_proportion: Fraction of ``lines`` assigned to the valid split.
    """
    total_lines = len(lines)
    # Compute each cut point once so the three slices always tile ``lines``.
    train_end = int(total_lines * train_proportion)
    valid_end = int(total_lines * (train_proportion + valid_proportion))

    splits = {
        "train_dataset.json": lines[:train_end],
        "valid_dataset.json": lines[train_end:valid_end],
        "test_dataset.json": lines[valid_end:],
    }
    for filename, split_lines in splits.items():
        # Records may contain non-ASCII text (the caller dumps them with
        # ensure_ascii=False), so never rely on the platform default encoding.
        with open(os.path.join(output_dir, filename), "w", encoding="utf-8") as f:
            f.write("\n".join(split_lines))
|
|
|
|
|
def _image_filename(url: str) -> str:
    """Return the local file name for ``url``: its last path segment, keeping only the final 100 characters."""
    return url.split("/")[-1][-100:]


def _download(url: str, path: str, retries: int, pause: float) -> bool:
    """Fetch ``url`` into ``path`` with up to ``retries`` attempts; return True on success."""
    for _ in range(retries):
        try:
            urllib.request.urlretrieve(url, path)
            return True
        except urllib.error.URLError:
            # HTTPError subclasses URLError, so HTTP failures are still retried;
            # transient network errors (DNS, refused connection) no longer abort the run.
            time.sleep(pause * 10)
    return False


def prepare_wit(
    tsv: str,
    language: str,
    output_dir: str,
    seed: int,
    train_proportion: float,
    valid_proportion: float,
    backup_period: int,
    language_col: str = "language",
    caption_col: str = "caption_reference_description",
    url_col: str = "image_url",
    pause: float = 0.875,
    retries: int = 10,
):
    """Download WIT images for one language and write train/valid/test splits.

    Rows of the tab-separated ``tsv`` whose ``language_col`` equals ``language``
    and that have both a caption and a URL are shuffled with ``seed``. Each
    image is downloaded into ``output_dir``, and a JSON record
    ``{"image_path": ..., "captions": [caption]}`` is collected for every
    successful download. Rows whose image file name already exists in
    ``output_dir`` are skipped.

    Args:
        tsv: Path to the WIT tab-separated file.
        language: Value of ``language_col`` to keep.
        output_dir: Directory for images and split files; created if missing.
        seed: ``random_state`` used to shuffle the rows.
        train_proportion: Fraction of records written to the train split.
        valid_proportion: Fraction of records written to the valid split.
        backup_period: Re-write the splits after every this many successful
            downloads; values <= 0 disable intermediate backups.
        language_col: Column holding the language code.
        caption_col: Column holding the caption text.
        url_col: Column holding the image URL.
        pause: Base pause; ``pause * 10`` seconds are slept after each failed attempt.
        retries: Maximum download attempts per image.
    """
    os.makedirs(output_dir, exist_ok=True)
    logger.info("Loading dataset")
    df = pd.read_csv(tsv, sep="\t", engine="python")
    df = df[
        (df[language_col] == language)
        & df[caption_col].notna()
        # A missing URL cannot be downloaded and would break _image_filename.
        & df[url_col].notna()
    ]
    # NOTE(review): images already on disk are filtered out here and never added
    # to ``lines``, so a resumed run overwrites the split files with only the newly
    # downloaded images. Confirm whether earlier records should be reloaded.
    existing_files = set(os.listdir(output_dir))
    already_downloaded = df[url_col].map(_image_filename).isin(existing_files)
    df = df[~already_downloaded]

    df = df.sample(frac=1.0, random_state=seed)
    logger.info(f"Trying to downloading {df.shape[0]} files")

    lines: List[str] = []
    try:
        for _, row in tqdm(df.iterrows(), total=len(df)):
            url = row[url_col]
            caption = row[caption_col]
            image_filename = _image_filename(url)
            image_path = os.path.join(output_dir, image_filename)

            if not _download(url, image_path, retries, pause):
                logger.info(f"Skipping {image_filename}")
                continue

            lines.append(
                json.dumps(
                    {"image_path": image_path, "captions": [caption]},
                    ensure_ascii=False,
                )
            )
            # Back up only right after a successful download, so each milestone
            # is saved exactly once and empty splits are never written early.
            if backup_period > 0 and len(lines) % backup_period == 0:
                logger.info(f"Saving dataset backup: Number of lines {len(lines)}")
                split_and_save_datasets(
                    lines, output_dir, train_proportion, valid_proportion
                )
    finally:
        # Always persist what was collected, even on error or Ctrl-C.
        split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)
|
|
|
|
|
if __name__ == "__main__":
    import getpass

    # Prefer $USER as before, but fall back to getpass.getuser() so the defaults
    # do not crash with KeyError in environments where $USER is unset.
    default_data_dir = f"/home/{os.environ.get('USER') or getpass.getuser()}/data/wit"

    parser = argparse.ArgumentParser(description="Download and prepare the WIT dataset")
    parser.add_argument(
        "--tsv",
        type=str,
        default=f"{default_data_dir}/wit_v1.train.all-1percent_sample.tsv",
    )
    parser.add_argument("--language", type=str, default="es")
    parser.add_argument(
        "--output_dir",
        type=str,
        default=f"{default_data_dir}/prepared_dataset",
    )
    parser.add_argument("--random_seed", type=int, default=0)
    parser.add_argument("--train_proportion", type=float, default=0.8)
    parser.add_argument("--valid_proportion", type=float, default=0.1)
    parser.add_argument("--backup_period", type=int, default=1000)

    args = parser.parse_args()
    # Validate with parser.error instead of assert: asserts vanish under `python -O`.
    if args.train_proportion < 0 or args.valid_proportion < 0:
        parser.error("train_proportion and valid_proportion must be non-negative")
    if not args.train_proportion + args.valid_proportion < 1.0:
        parser.error("The sum of train_proportion and valid_proportion has to be < 1.0")

    prepare_wit(
        args.tsv,
        args.language,
        args.output_dir,
        args.random_seed,
        args.train_proportion,
        args.valid_proportion,
        args.backup_period,
    )
|
|