edugp committed
Commit 2daf3c7
1 Parent(s): 8a1113b

Add all necessary files to replicate training run
README.md CHANGED
@@ -1,5 +1,7 @@
  # Download datasets:
  * Download and decompress tsv file from here: https://github.com/google-research-datasets/wit/blob/main/DATA.md
- * Use `prepare_wit.py` to download images from Wikipedia.
- * Use `discard_incorrect_files` to filter out corrupt files.`TODO: Still some corrupt files are being kept.` `TODO: Make it a CLI`.
+ * Use `prepare_wit.py` to download images from Wikipedia, as annotated in each TSV file.
+ * Use `scale_convert.py` to remove corrupt images and resize suitable images to 224x224.
+ * Use `join_datasets_custom_split.py` to group the JSONs from the different dataset subsets together.
+ * Use `discard_incorrect_files.py` to filter out images that could not be converted.
  * Finally, use `run-clip.sh` to train.
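
Taken together, the bullets above run in this order; a minimal driver sketch (the `prepare_wit.py` arguments are elided here, and `in_dir`/`out_dir` are placeholders):

```python
import subprocess

steps = [
    "python prepare_wit.py ...",               # 1. download images listed in the TSVs
    "python scale_convert.py in_dir out_dir",  # 2. drop corrupt images, resize to 224x224
    "python join_datasets_custom_split.py",    # 3. merge JSONs into a 98/1/1 split
    "python discard_incorrect_files.py",       # 4. keep only successfully converted images
    "bash run-clip.sh",                        # 5. train
]
for step in steps:
    subprocess.run(step, shell=True, check=True)  # stop the pipeline on any failure
```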
discard_incorrect_files.py CHANGED
@@ -1,23 +1,28 @@
  import json
  import os
+ from tqdm import tqdm

  import torch
  from torchvision.io import ImageReadMode, read_image

- # SUPPORTED_EXTENSIONS = {'PNG', 'JPG', 'png', 'JPEG', 'jpg', 'jpeg'}
+ JOINT_JSON_DIRECTORY = f"/home/{os.environ['USER']}/data/wit/all_jsons"
+ SCALE_CONVERTED_DIRECTORY = f"/home/{os.environ['USER']}/data/wit_scale_converted"

  for split in ["train", "valid", "test"]:
-     with open(f"/home/{os.environ['USER']}/data/wit/prepared_dataset/{split}_dataset.json") as f:
+     print("Reading json")
+     with open(f"{JOINT_JSON_DIRECTORY}/{split}_dataset_all_98_1_1_split.json") as f:
          examples = [json.loads(line) for line in f.readlines()]
-
+     valid_files = set(os.listdir(SCALE_CONVERTED_DIRECTORY))
+
      supported_examples = []
-     for example in examples:
-         try:
-             image = read_image(example["image_path"], mode=ImageReadMode.RGB)
+     for example in tqdm(examples):
+         directory, filename = os.path.split(example['image_path'])
+         if filename in valid_files:
+             example["image_path"] = os.path.join(SCALE_CONVERTED_DIRECTORY, filename)
              supported_examples.append(json.dumps(example, ensure_ascii=False))
-         except Exception as e:
-             print(f"Excluding file: {example['image_path']} due to error: {e}")

      print(f"Total {split} examples: {len(supported_examples)}")
-     with open(f"/home/{os.environ['USER']}/data/wit/prepared_dataset/{split}_dataset_filtered.json", "w") as f:
+     with open(f"{SCALE_CONVERTED_DIRECTORY}/{split}_dataset_scale_converted_98_1_1_split.json", "w") as f:
          f.write("\n".join(supported_examples))
+
+ print("DONE!")
join_datasets_custom_split.py ADDED
@@ -0,0 +1,48 @@
+ import os
+ import json
+ import random
+
+ import pandas as pd
+
+
+ DATA_DIR = f"/home/{os.environ['USER']}/data/wit/all_jsons"
+ SEED = 0
+ PROPORTION_TRAIN = 0.98
+ PROPORTION_VALID = 0.01
+
+ random.seed(SEED)
+
+ all_files = [f"{DATA_DIR}/{file_}" for file_ in os.listdir(DATA_DIR) if ("all" not in file_)]
+
+ print(all_files)
+
+ examples = []
+ for file_ in all_files:
+     print(file_)
+     with open(file_) as f:
+         file_examples = [json.dumps(json.loads(line), ensure_ascii=False) for line in f.readlines()]
+         print(len(file_examples))
+         examples.extend(file_examples)
+
+ print(f"Before dedup: {len(examples)}")
+ examples = list(set(examples))
+ print(f"After dedup: {len(examples)}")
+
+ print(examples[0])
+ # Shuffle examples
+ random.shuffle(examples)
+ print(examples[0])
+
+ split_dataset = {}
+ split_dataset["train"] = examples[:int(len(examples) * PROPORTION_TRAIN)]
+ split_dataset["valid"] = examples[int(len(examples) * PROPORTION_TRAIN): int(len(examples) * (PROPORTION_TRAIN + PROPORTION_VALID))]
+ split_dataset["test"] = examples[int(len(examples) * (PROPORTION_TRAIN + PROPORTION_VALID)):]
+
+
+ for split in ["train", "valid", "test"]:
+     print("-----")
+     print(len(split_dataset[split]))
+     print("-----")
+     with open(f"/home/{os.environ['USER']}/data/wit/all_jsons/{split}_dataset_all_98_1_1_split.json", "w") as f:
+         f.write("\n".join(split_dataset[split]))
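
The 98/1/1 split is plain list slicing over the shuffled, deduplicated examples. A toy check (illustrative numbers only) that the three slices partition the list:

```python
examples = list(range(1000))
n = len(examples)
train = examples[:int(n * 0.98)]               # 980 examples
valid = examples[int(n * 0.98):int(n * 0.99)]  # 10 examples
test = examples[int(n * 0.99):]                # 10 examples
assert len(train) + len(valid) + len(test) == n  # nothing lost, nothing duplicated
```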
prepare_wit.py CHANGED
@@ -33,14 +33,17 @@ def split_and_save_datasets(lines: List[str], output_dir: str, train_proportion:
      with open(f"{output_dir}/test_dataset.json", "w") as f:
          f.write("\n".join(test_lines))

- def prepare_wit(tsv: str, language: str, output_dir: str, seed: int, train_proportion: float, valid_proportion: float, backup_period: int, language_col: str="language", caption_col: str="caption_reference_description", url_col: str="image_url", pause=0.1, retries: int=5):
+ def prepare_wit(
+     tsv: str, language: str, output_dir: str, seed: int, train_proportion: float, valid_proportion: float, backup_period: int, language_col: str="language", caption_col: str="caption_reference_description", url_col: str="image_url", pause=0.875, retries: int=10):
      os.makedirs(output_dir, exist_ok=True)
      logger.info("Loading dataset")
      df = pd.read_csv(tsv, sep="\t", engine="python")
-     df = df[(df["language"] == language) & (~df["caption_reference_description"].isnull())]
+     existing_files = set(os.listdir(output_dir))
+     not_exists_condition = (~(df[url_col].map(lambda x: x.split("/")[-1][-100:]).isin(existing_files)))
+     df = df[(df["language"] == language) & (~df["caption_reference_description"].isnull()) & not_exists_condition]
      # Shuffle
      df = df.sample(frac=1.0, random_state=seed)
-     logger.info("Download started")
+     logger.info(f"Trying to download {df.shape[0]} files")
      lines = []
      count = 0
      try:
@@ -49,7 +52,7 @@ def prepare_wit(tsv: str, language: str, output_dir: str, seed: int, train_propo
              url = row[url_col]
              caption = row[caption_col]
              # Trim image file names so that they are no longer than 100 characters
-             image_filename = url.split('/')[-1][-100:]
+             image_filename = url.split("/")[-1][-100:]
              image_path = f"{output_dir}/{image_filename}"
              for retry in range(retries):
                  try:
@@ -59,13 +62,12 @@ def prepare_wit(tsv: str, language: str, output_dir: str, seed: int, train_propo
                      count += 1
                      break
                  except urllib.error.HTTPError as e:
-                     # time.sleep(pause)
-                     pass
+                     time.sleep(pause * 10)
                  if count % backup_period == 0:
                      logger.info(f"Saving dataset backup: Number of lines {len(lines)}")
                      split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)
-                 if retry == retries:
-                     raise ValueError("Rate limit achieved:", e)
+                 if retry == retries - 1:
+                     logger.info(f"Skipping {image_filename}")
              pbar.update(1)
      # Save existing dataset, even upon failure
      finally:
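
The new error handling sleeps `pause * 10` seconds after every `HTTPError` and, instead of raising, skips the file once the attempts are exhausted. The same pattern as a self-contained sketch (generic names, not the script's exact code):

```python
import time
import urllib.error
import urllib.request

def download_with_retries(url: str, path: str, pause: float = 0.875, retries: int = 10) -> bool:
    for retry in range(retries):
        try:
            urllib.request.urlretrieve(url, path)
            return True
        except urllib.error.HTTPError:
            time.sleep(pause * 10)  # back off before the next attempt
    return False  # all retries exhausted: caller logs and skips this file
```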
run-clip.sh CHANGED
@@ -1,12 +1,12 @@
- HUB_TOKEN=`cat $HOME/.huggingface/token`
  python run_hybrid_clip.py \
-     --output_dir "./output_dir" \
+     --output_dir "./output_141230_training_examples" \
      --text_model_name_or_path="dccuchile/bert-base-spanish-wwm-cased" \
      --vision_model_name_or_path="openai/clip-vit-base-patch32" \
      --tokenizer_name="dccuchile/bert-base-spanish-wwm-cased" \
-     --train_file="/home/${USER}/data/wit/prepared_dataset/train_dataset_filtered.json" \
-     --validation_file="/home/${USER}/data/wit/prepared_dataset/valid_dataset_filtered.json" \
-     --do_train --do_eval \
+     --train_file="/home/${USER}/data/wit_scale_converted/train_dataset_scale_converted_98_1_1_split.json" \
+     --validation_file="/home/${USER}/data/wit_scale_converted/valid_dataset_scale_converted_98_1_1_split.json" \
+     --do_train \
+     --do_eval \
      --num_train_epochs="40" \
      --max_seq_length 96 \
      --per_device_train_batch_size="64" \
@@ -14,5 +14,3 @@ python run_hybrid_clip.py \
      --learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
      --overwrite_output_dir \
      --preprocessing_num_workers 32
-     #--push_to_hub
-
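
The new output directory name suggests the run saw 141,230 training examples; a small sketch (paths taken from the flags above) to check the split sizes before launching:

```python
import os

base = f"/home/{os.environ['USER']}/data/wit_scale_converted"
for split in ("train", "valid"):
    path = f"{base}/{split}_dataset_scale_converted_98_1_1_split.json"
    with open(path) as f:
        # One JSON object per line, so line count == example count.
        print(split, sum(1 for _ in f))
```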
 
 
run_hybrid_clip.py CHANGED
@@ -225,8 +225,9 @@ class ImageTextDataset(VisionDataset):
          self.image_paths = []

          for example in examples:
-             self.captions.extend(example["captions"][:captions_per_image])
-             self.image_paths.extend([example["image_path"]] * captions_per_image)
+             captions_subset = example["captions"][:captions_per_image]
+             self.captions.extend(captions_subset)
+             self.image_paths.extend([example["image_path"]] * len(captions_subset))

      def _load_image(self, idx: int):
          path = self.image_paths[idx]
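
Why this matters, as a toy example: with `captions_per_image = 2` and an example holding a single caption, the old code appended one caption but two copies of the image path, desynchronizing the two parallel lists:

```python
captions_per_image = 2
example = {"captions": ["only caption"], "image_path": "img.jpg"}

captions_subset = example["captions"][:captions_per_image]    # 1 caption survives
image_paths = [example["image_path"]] * len(captions_subset)  # fixed: 1 path, not 2
assert len(captions_subset) == len(image_paths)
```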
scale_convert.py ADDED
@@ -0,0 +1,53 @@
+ import glob
+ import itertools
+ from argparse import ArgumentParser
+ from joblib import Parallel, delayed
+ import os
+ import subprocess
+ from collections import Counter
+ import shutil
+
+
+ parser = ArgumentParser()
+ parser.add_argument("in_dir")
+ parser.add_argument("out_dir")
+ args = parser.parse_args()
+
+ os.makedirs(args.out_dir, exist_ok=True)
+
+ files = itertools.chain(
+     glob.iglob(f"{args.in_dir}/*/*.jpg"),
+     glob.iglob(f"{args.in_dir}/*/*.JPG"),
+     glob.iglob(f"{args.in_dir}/*/*.jpeg"),
+     glob.iglob(f"{args.in_dir}/*/*.JPEG"),
+     glob.iglob(f"{args.in_dir}/*/*.png"),
+     glob.iglob(f"{args.in_dir}/*/*.PNG"),
+     glob.iglob(f"{args.in_dir}/*/*.svg"),
+     glob.iglob(f"{args.in_dir}/*/*.SVG"),
+ )
+
+ def process_file(path):
+     basename = os.path.basename(path)
+     ext = os.path.splitext(basename)[1]
+     name = os.path.splitext(basename)[0]
+
+     dirname = os.path.dirname(path)
+     try:
+         r = subprocess.run(
+             f'convert {path} -resize "224^>" -colorspace RGB -density 1200 {args.out_dir}/{name}.jpg',
+             shell=True,
+             timeout=10
+         )
+         rcode = r.returncode
+     except subprocess.TimeoutExpired:
+         print("conversion timeout expired")
+         rcode = -1
+
+     if rcode == 0:
+         os.remove(path)
+
+     return rcode
+
+ codes = Parallel(n_jobs=32, prefer="threads", verbose=1)(delayed(process_file)(f) for f in files)
+ print(Counter(codes))
+
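
For reference, a rough PIL equivalent (an assumption, not part of this commit) of the ImageMagick call: `-resize "224^>"` scales the image so its shorter side becomes 224, and the `>` modifier shrinks only, never enlarges:

```python
from PIL import Image

def scale_convert(in_path: str, out_path: str) -> None:
    img = Image.open(in_path).convert("RGB")
    w, h = img.size
    scale = 224 / min(w, h)
    if scale < 1.0:  # the ">" modifier: only shrink images larger than the target
        img = img.resize((round(w * scale), round(h * scale)), Image.BICUBIC)
    img.save(out_path, "JPEG")
```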
test_on_image.py CHANGED
@@ -1,3 +1,5 @@
+ import os
+
  import jax
  import torch
  from torchvision.io import ImageReadMode, read_image
@@ -28,7 +30,7 @@ def run_inference(image_path, text):
      score = jax.nn.sigmoid(logits)
      return score

- image_path = "/home/eduardogonzalezponferrada/data/wit/full_dataset/Casa_de_Cultura_%284%29.JPG"
+ image_path = f"/home/{os.environ['USER']}/data/wit_scale_converted/Self_Portrait_by_David_Allan.jpg"
  text = "Patio interior de un edificio"

- print(run_inference(image_path, text))
+ print(run_inference(image_path, text))
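
Since `run_inference` pushes the raw image-text logit through `jax.nn.sigmoid`, the printed score always lands in (0, 1), with higher values indicating a better image-caption match:

```python
import jax
import jax.numpy as jnp

print(jax.nn.sigmoid(jnp.array([-2.0, 0.0, 2.0])))  # ~[0.12, 0.5, 0.88]
```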