# clip-spanish / prepare_wit.py
# Last commit 8e2b754 by edugp: "Add training scripts and initial model trained on 1% of the data."
# (file size at that commit: 3.26 kB)
import argparse
import json
import logging
import os
import time
import urllib.request
import urllib.error
import pandas as pd
from tqdm import tqdm
logger = logging.getLogger(__name__)


def _download_with_retries(url: str, image_path: str, pause: float, retries: int) -> bool:
    """Download ``url`` to ``image_path``, making up to ``retries`` attempts.

    Sleeps ``pause`` seconds between consecutive attempts. Returns True as soon
    as one attempt succeeds. Returns False when every attempt failed, or when
    ``retries`` < 1 and nothing was attempted. A persistent failure other than
    rate limiting is logged and skipped, so one dead link does not abort the
    whole run.

    Raises:
        ValueError: If the final attempt fails with HTTP 429 (Too Many Requests).
    """
    for attempt in range(retries):
        try:
            urllib.request.urlretrieve(url, image_path)
            return True
        except urllib.error.URLError as e:  # HTTPError is a subclass of URLError
            if attempt < retries - 1:
                time.sleep(pause)
                continue
            # Last attempt failed: being rate limited means further downloads
            # will fail too, so stop; anything else only affects this image.
            if isinstance(e, urllib.error.HTTPError) and e.code == 429:
                raise ValueError(f"Rate limit achieved: {e}") from e
            logger.warning("Skipping %s after %d failed attempt(s): %s", url, retries, e)
    return False


def _write_splits(lines: list, output_dir: str, train_proportion: float, valid_proportion: float) -> None:
    """Write ``lines`` in order into train/valid/test JSON-lines files in ``output_dir``.

    The first ``int(n * train_proportion)`` lines go to train. Lines up to
    ``int(n * (train_proportion + valid_proportion))`` go to valid, and the
    rest go to test.
    """
    total_lines = len(lines)
    train_end = int(total_lines * train_proportion)
    valid_end = int(total_lines * (train_proportion + valid_proportion))
    splits = (
        ("train", lines[:train_end]),
        ("valid", lines[train_end:valid_end]),
        ("test", lines[valid_end:]),
    )
    for split_name, split_lines in splits:
        # Explicit UTF-8: captions are dumped with ensure_ascii=False, so the
        # platform default encoding could fail on or garble non-ASCII text.
        with open(f"{output_dir}/{split_name}_dataset.json", "w", encoding="utf-8") as f:
            f.write("\n".join(split_lines))


def prepare_wit(tsv: str, language: str, output_dir: str, seed: int, train_proportion: float, valid_proportion: float, language_col: str = "language", caption_col: str = "caption_reference_description", url_col: str = "image_url", pause: float = 1.0, retries: int = 5):
    """Download WIT images for one language and write train/valid/test splits.

    The tab-separated file ``tsv`` is filtered to rows where ``language_col``
    equals ``language`` and both ``caption_col`` and ``url_col`` are non-null.
    The rows are then shuffled with ``seed``. Each row's image is downloaded
    into ``output_dir``. Every successful download produces one JSON line,
    ``{"image_path": ..., "captions": [caption]}``. The collected lines are
    split in order into ``train_dataset.json``, ``valid_dataset.json`` and
    ``test_dataset.json`` (UTF-8, one JSON object per line). The test split
    receives whatever is left after train and valid.

    The split files are written even if the download loop is interrupted by an
    exception, so partial progress is kept.

    Args:
        tsv: Path to the WIT TSV file.
        language: Value of ``language_col`` to keep (e.g. ``"es"``).
        output_dir: Directory for downloaded images and split files; created if missing.
        seed: ``random_state`` for the shuffle.
        train_proportion: Fraction of downloaded examples assigned to train.
        valid_proportion: Fraction of downloaded examples assigned to valid.
        language_col: Column holding the language code.
        caption_col: Column holding the caption text.
        url_col: Column holding the image URL.
        pause: Seconds to sleep between download attempts for the same URL.
        retries: Maximum number of download attempts per URL.

    Raises:
        ValueError: If an image download still fails with HTTP 429 on its last
            attempt. The splits collected so far are written first.
    """
    os.makedirs(output_dir, exist_ok=True)
    df = pd.read_csv(tsv, sep="\t", engine="python")
    df = df[(df[language_col] == language) & df[caption_col].notna() & df[url_col].notna()]
    # Shuffle
    df = df.sample(frac=1.0, random_state=seed)
    lines = []
    try:
        for url, caption in tqdm(zip(df[url_col], df[caption_col]), total=len(df)):
            # Keep at most the last 100 characters of the URL's final path segment.
            # NOTE: distinct URLs sharing that trailing name overwrite each other.
            image_filename = url.split("/")[-1][-100:]
            image_path = f"{output_dir}/{image_filename}"
            if _download_with_retries(url, image_path, pause, retries):
                lines.append(json.dumps({"image_path": image_path, "captions": [caption]}, ensure_ascii=False))
    # Save existing dataset, even upon failure
    finally:
        _write_splits(lines, output_dir, train_proportion, valid_proportion)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description = "Download and prepare the WIT dataset")
parser.add_argument("--tsv", type=str, default=f"/home/{os.environ['USER']}/data/wit/wit_v1.train.all-1percent_sample.tsv")
parser.add_argument("--language", type=str, default="es")
parser.add_argument("--output_dir", type=str, default=f"/home/{os.environ['USER']}/data/wit/prepared_dataset")
parser.add_argument("--random_seed", type=int, default=0)
parser.add_argument("--train_proportion", type=float, default=0.8)
parser.add_argument("--valid_proportion", type=float, default=0.1)
args = parser.parse_args()
assert args.train_proportion + args.valid_proportion < 1.0, "The sum of train_proportion and valid_proportion has to be < 1.0"
prepare_wit(args.tsv, args.language, args.output_dir, args.random_seed, args.train_proportion, args.valid_proportion)