Add all necessary files to replicate training run

Files changed:

- README.md (+4 -2)
- discard_incorrect_files.py (+14 -9)
- join_datasets_custom_split.py (+48 -0)
- prepare_wit.py (+10 -8)
- run-clip.sh (+5 -7)
- run_hybrid_clip.py (+3 -2)
- scale_convert.py (+53 -0)
- test_on_image.py (+4 -2)
README.md (CHANGED)

The download instructions now walk through the full preprocessing pipeline; the two old bullet steps are replaced by one step per script:

```markdown
# Download datasets:
* Download and decompress the tsv file from here: https://github.com/google-research-datasets/wit/blob/main/DATA.md
* Use `prepare_wit.py` to download the images from Wikipedia annotated in each TSV file.
* Use `scale_convert.py` to remove corrupt images and resize suitable images to 224x224.
* Use `join_datasets_custom_split.py` to group all JSONs from the different subsets of the dataset together.
* Use `discard_incorrect_files.py` to filter out images that could not be converted.
* Finally, use `run-clip.sh` to train.
```
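Taken together, the bullets above form a linear pipeline. Below is a minimal sketch of driving it end to end; the directory layout is an assumption, the command-line interface of `prepare_wit.py` is not shown in this diff (so that step is left as a placeholder comment), and `join_datasets_custom_split.py` and `discard_incorrect_files.py` read their paths from hardcoded constants, so they take no arguments.

```python
import subprocess

RAW_IMAGES = "/home/user/data/wit/images"          # assumed prepare_wit.py output dir
CONVERTED = "/home/user/data/wit_scale_converted"  # directory the scripts below expect

# Step 1: run prepare_wit.py on each downloaded TSV (flags not shown in this diff).

# Step 2: convert/resize images; scale_convert.py takes positional in_dir/out_dir.
subprocess.run(["python", "scale_convert.py", RAW_IMAGES, CONVERTED], check=True)

# Steps 3-4: join the per-subset JSONs into a 98/1/1 split, then drop examples
# whose image did not survive conversion. Both scripts use hardcoded paths.
subprocess.run(["python", "join_datasets_custom_split.py"], check=True)
subprocess.run(["python", "discard_incorrect_files.py"], check=True)

# Step 5: train.
subprocess.run(["bash", "run-clip.sh"], check=True)
```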
discard_incorrect_files.py (CHANGED)

Instead of opening every image with torchvision and excluding files that raise an exception (the previous approach), the script now keeps an example only if its image is present in the scale-converted directory, rewriting `image_path` accordingly:

```python
import json
import os
from tqdm import tqdm

import torch
from torchvision.io import ImageReadMode, read_image

JOINT_JSON_DIRECTORY = f"/home/{os.environ['USER']}/data/wit/all_jsons"
SCALE_CONVERTED_DIRECTORY = f"/home/{os.environ['USER']}/data/wit_scale_converted"

for split in ["train", "valid", "test"]:
    print("Reading json")
    with open(f"{JOINT_JSON_DIRECTORY}/{split}_dataset_all_98_1_1_split.json") as f:
        examples = [json.loads(line) for line in f.readlines()]
    valid_files = set(os.listdir(SCALE_CONVERTED_DIRECTORY))

    supported_examples = []
    for example in tqdm(examples):
        # Keep the example only if its image survived scale_convert.py,
        # and point its path at the converted copy.
        directory, filename = os.path.split(example["image_path"])
        if filename in valid_files:
            example["image_path"] = os.path.join(SCALE_CONVERTED_DIRECTORY, filename)
            supported_examples.append(json.dumps(example, ensure_ascii=False))

    print(f"Total {split} examples: {len(supported_examples)}")
    with open(f"{SCALE_CONVERTED_DIRECTORY}/{split}_dataset_scale_converted_98_1_1_split.json", "w") as f:
        f.write("\n".join(supported_examples))

print("DONE!")
```

The `torch`/`torchvision` imports are left over from the previous validation approach and are no longer used.
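For reference, each line in these dataset files is one standalone JSON object. A hedged sketch of the record shape, using only the field names the scripts actually access (`image_path` here, `captions` in `run_hybrid_clip.py`); the values are invented:

```python
import json

# Hypothetical record; only the field names come from the scripts.
record = {
    "image_path": "/home/user/data/wit_scale_converted/example.jpg",
    "captions": ["Un ejemplo de leyenda"],
}
line = json.dumps(record, ensure_ascii=False)
assert json.loads(line)["image_path"].endswith("example.jpg")
```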
join_datasets_custom_split.py (ADDED)

New script: it concatenates all per-subset JSONs, deduplicates the lines, shuffles them with a fixed seed, and writes a 98/1/1 train/valid/test split:

```python
import os
import json
import random

import pandas as pd

DATA_DIR = f"/home/{os.environ['USER']}/data/wit/all_jsons"
SEED = 0
PROPORTION_TRAIN = 0.98
PROPORTION_VALID = 0.01

random.seed(SEED)

# Skip previous joint outputs, whose names contain "all".
all_files = [f"{DATA_DIR}/{file_}" for file_ in os.listdir(DATA_DIR) if ("all" not in file_)]

print(all_files)

examples = []
for file_ in all_files:
    print(file_)
    with open(file_) as f:
        file_examples = [json.dumps(json.loads(line), ensure_ascii=False) for line in f.readlines()]
        print(len(file_examples))
        examples.extend(file_examples)

print(f"Before dedup: {len(examples)}")
examples = list(set(examples))
print(f"After dedup: {len(examples)}")

print(examples[0])
# Shuffle examples
random.shuffle(examples)
print(examples[0])

split_dataset = {}
split_dataset["train"] = examples[:int(len(examples) * PROPORTION_TRAIN)]
split_dataset["valid"] = examples[int(len(examples) * PROPORTION_TRAIN): int(len(examples) * (PROPORTION_TRAIN + PROPORTION_VALID))]
split_dataset["test"] = examples[int(len(examples) * (PROPORTION_TRAIN + PROPORTION_VALID)):]

for split in ["train", "valid", "test"]:
    print("-----")
    print(len(split_dataset[split]))
    print("-----")
    with open(f"{DATA_DIR}/{split}_dataset_all_98_1_1_split.json", "w") as f:
        f.write("\n".join(split_dataset[split]))
```

(The `pandas` import is unused.)
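Because the split boundaries come from truncated proportions of the deduplicated list, the three slices are disjoint and cover every example. A quick worked check, with an invented total of 144105 examples:

```python
n = 144105                          # hypothetical dataset size
train_end = int(n * 0.98)           # 141222
valid_end = int(n * (0.98 + 0.01))  # 142663
print(train_end, valid_end - train_end, n - valid_end)  # 141222 1441 1442
```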
prepare_wit.py (CHANGED)

`prepare_wit()` now takes its full configuration as explicit parameters, skips images that already exist in `output_dir` (so interrupted runs can resume), sleeps after an HTTP error instead of silently passing, and logs a skip once all retries are exhausted. The old exhaustion check `if retry == retries:` could never fire, since `range(retries)` ends at `retries - 1`; the new check fixes that. First hunk:

```python
    with open(f"{output_dir}/test_dataset.json", "w") as f:
        f.write("\n".join(test_lines))

def prepare_wit(
        tsv: str, language: str, output_dir: str, seed: int, train_proportion: float,
        valid_proportion: float, backup_period: int, language_col: str = "language",
        caption_col: str = "caption_reference_description", url_col: str = "image_url",
        pause=0.875, retries: int = 10):
    os.makedirs(output_dir, exist_ok=True)
    logger.info("Loading dataset")
    df = pd.read_csv(tsv, sep="\t", engine="python")
    # Resume support: drop rows whose image file was already downloaded.
    existing_files = set(os.listdir(output_dir))
    not_exists_condition = (~(df[url_col].map(lambda x: x.split("/")[-1][-100:]).isin(existing_files)))
    df = df[(df["language"] == language) & (~df["caption_reference_description"].isnull()) & not_exists_condition]
    # Shuffle
    df = df.sample(frac=1.0, random_state=seed)
    logger.info(f"Trying to download {df.shape[0]} files")
```

Second hunk, inside the download loop:

```python
            url = row[url_col]
            caption = row[caption_col]
            # Trim image file names so that they are no longer than 100 characters
            image_filename = url.split("/")[-1][-100:]
            image_path = f"{output_dir}/{image_filename}"
            for retry in range(retries):
                try:
                    ...  # download call elided in this diff
                    count += 1
                    break
                except urllib.error.HTTPError as e:
                    # Back off before the next attempt.
                    time.sleep(pause * 10)
                if count % backup_period == 0:
                    logger.info(f"Saving dataset backup: Number of lines {len(lines)}")
                    split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)
                if retry == retries - 1:
                    logger.info(f"Skipping {image_filename}")
            pbar.update(1)
    # Save existing dataset, even upon failure
    finally:
```
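The `[-100:]` slice keeps only the last 100 characters of the URL's final path segment, which caps filename length while preserving the extension, at the cost of possible collisions between long names that share a tail. A quick illustration with an invented URL:

```python
url = "https://upload.wikimedia.org/wikipedia/commons/d/d6/" + "x" * 120 + ".jpg"
image_filename = url.split("/")[-1][-100:]
print(len(image_filename), image_filename[-4:])  # 100 .jpg
```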
run-clip.sh (CHANGED)

The launch script drops the unused `HUB_TOKEN` line and the commented-out `--push_to_hub` flag, points the train and validation files at the new scale-converted split, and enables evaluation alongside training:

```bash
python run_hybrid_clip.py \
    --output_dir "./output_141230_training_examples" \
    --text_model_name_or_path="dccuchile/bert-base-spanish-wwm-cased" \
    --vision_model_name_or_path="openai/clip-vit-base-patch32" \
    --tokenizer_name="dccuchile/bert-base-spanish-wwm-cased" \
    --train_file="/home/${USER}/data/wit_scale_converted/train_dataset_scale_converted_98_1_1_split.json" \
    --validation_file="/home/${USER}/data/wit_scale_converted/valid_dataset_scale_converted_98_1_1_split.json" \
    --do_train \
    --do_eval \
    --num_train_epochs="40" \
    --max_seq_length 96 \
    --per_device_train_batch_size="64" \
```

(one flag between the hunks is not shown in this diff)

```bash
    --learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
    --overwrite_output_dir \
    --preprocessing_num_workers 32
```
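The output directory name suggests 141230 training examples. Assuming a single device (our assumption; JAX may shard across several), the per-device batch size of 64 gives roughly 2206 optimizer steps per epoch, about 88k over the 40 epochs:

```python
examples, batch, epochs = 141_230, 64, 40
steps_per_epoch = examples // batch               # 2206 full batches per epoch
print(steps_per_epoch, steps_per_epoch * epochs)  # 2206 88240
```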
run_hybrid_clip.py (CHANGED)

`ImageTextDataset` now honors `captions_per_image`: it keeps at most that many captions per example and repeats the image path once per kept caption, so the two lists stay index-aligned:

```python
        self.image_paths = []

        for example in examples:
            captions_subset = example["captions"][:captions_per_image]
            self.captions.extend(captions_subset)
            self.image_paths.extend([example["image_path"]] * len(captions_subset))

    def _load_image(self, idx: int):
        path = self.image_paths[idx]
```
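A small sketch of the resulting expansion, with invented record contents and `captions_per_image=2`:

```python
examples = [{"image_path": "img_a.jpg", "captions": ["cap 1", "cap 2", "cap 3"]}]
captions_per_image = 2

captions, image_paths = [], []
for example in examples:
    captions_subset = example["captions"][:captions_per_image]
    captions.extend(captions_subset)
    image_paths.extend([example["image_path"]] * len(captions_subset))

print(list(zip(image_paths, captions)))
# [('img_a.jpg', 'cap 1'), ('img_a.jpg', 'cap 2')]
```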
scale_convert.py (ADDED)

New script: it walks the downloaded images, converts each one to a 224px RGB JPEG with ImageMagick's `convert` (aborting any conversion that takes longer than 10 seconds), deletes originals that converted successfully, and prints a tally of return codes:

```python
import glob
import itertools
from argparse import ArgumentParser
from joblib import Parallel, delayed
import os
import subprocess
from collections import Counter
import shutil


parser = ArgumentParser()
parser.add_argument("in_dir")
parser.add_argument("out_dir")
args = parser.parse_args()

os.makedirs(args.out_dir, exist_ok=True)

# Images sit one directory level below in_dir; match common extensions.
files = itertools.chain(
    glob.iglob(f"{args.in_dir}/*/*.jpg"),
    glob.iglob(f"{args.in_dir}/*/*.JPG"),
    glob.iglob(f"{args.in_dir}/*/*.jpeg"),
    glob.iglob(f"{args.in_dir}/*/*.JPEG"),
    glob.iglob(f"{args.in_dir}/*/*.png"),
    glob.iglob(f"{args.in_dir}/*/*.PNG"),
    glob.iglob(f"{args.in_dir}/*/*.svg"),
    glob.iglob(f"{args.in_dir}/*/*.SVG"),
)


def process_file(path):
    basename = os.path.basename(path)
    name = os.path.splitext(basename)[0]

    try:
        # "224^>" resizes so the shorter side becomes 224px, and only
        # shrinks images that are larger; output is an RGB JPEG.
        r = subprocess.run(
            f'convert {path} -resize "224^>" -colorspace RGB -density 1200 {args.out_dir}/{name}.jpg',
            shell=True,
            timeout=10,
        )
        rcode = r.returncode
    except subprocess.TimeoutExpired:
        print("conversion timeout expired")
        rcode = -1

    if rcode == 0:
        # Conversion succeeded: remove the original to save disk space.
        os.remove(path)

    return rcode


codes = Parallel(n_jobs=32, prefer="threads", verbose=1)(delayed(process_file)(f) for f in files)
print(Counter(codes))
```
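A hedged way to sanity-check the converted directory afterwards; this uses Pillow, which the repository scripts do not themselves depend on, and the path is an assumption:

```python
import glob
from PIL import Image

small = 0
for path in glob.iglob("/home/user/data/wit_scale_converted/*.jpg"):  # assumed out_dir
    with Image.open(path) as im:
        # Possible by design: "224^>" only shrinks larger images, so
        # originals already smaller than 224px pass through untouched.
        if min(im.size) < 224:
            small += 1
print(f"images with short side < 224: {small}")
```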
test_on_image.py (CHANGED)

The hardcoded test image now points into the scale-converted directory, with the path built from `$USER` instead of a fixed home directory:

```python
import os

import jax
import torch
from torchvision.io import ImageReadMode, read_image

# ... (run_inference definition, unchanged in this diff, ends with:)
    score = jax.nn.sigmoid(logits)
    return score

image_path = f"/home/{os.environ['USER']}/data/wit_scale_converted/Self_Portrait_by_David_Allan.jpg"
text = "Patio interior de un edificio"

print(run_inference(image_path, text))
```
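Since `run_inference` returns a sigmoid score, ranking candidate captions for one image is just repeated calls. A sketch; the alternative caption is invented:

```python
for caption in ["Patio interior de un edificio", "Un retrato de un hombre"]:
    print(caption, run_inference(image_path, caption))
```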
|