shpotes committed
Commit b45f0b4
1 Parent(s): 36fdb4d

add baseline

src/run_medclip.py → run_medclip.py RENAMED
@@ -28,6 +28,7 @@ import logging
 import os
 import sys
 import time
+import getpass
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Callable, Optional
@@ -47,9 +48,9 @@ from flax import jax_utils
 from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, shard, shard_prng_key
-from modeling_hybrid_clip import FlaxHybridCLIP
+from src.modeling_medclip import FlaxHybridCLIP
 from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
-
+import wandb
 
 logger = logging.getLogger(__name__)
 
@@ -210,7 +211,6 @@ class ImageTextDataset(VisionDataset):
         self,
         root: str,
         file_path: str,
-        captions_per_image=2,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         transforms: Optional[Callable] = None,
@@ -224,15 +224,21 @@ class ImageTextDataset(VisionDataset):
         self.image_paths = []
 
         for example in examples:
-            self.captions.extend(example["captions"][:captions_per_image])
-            self.image_paths.extend([example["image_path"]] * captions_per_image)
+            self.captions.append(example["caption"])
+            self.image_paths.append(f'{root}/{example["image_path"]}')
 
     def _load_image(self, idx: int):
         path = self.image_paths[idx]
         return read_image(path, mode=ImageReadMode.RGB)
 
     def _load_target(self, idx):
-        return self.captions[idx]
+        sections = self.captions[idx]
+        longest_section = max(
+            filter(lambda x: isinstance(x, str), sections.values()),
+            key=len
+        )
+
+        return longest_section
 
     def __getitem__(self, index: int):
         image = self._load_image(index)
@@ -290,6 +296,17 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+    if jax.process_index() == 0:
+        wandb.init(
+            entity=getpass.getuser(),
+            project='medclip',
+            sync_tensorboard=True
+        )
+
+        wandb.config.update(model_args)
+        wandb.config.update(data_args)
+        wandb.config.update(training_args)
+
     if (
         os.path.exists(training_args.output_dir)
         and os.listdir(training_args.output_dir)
@@ -351,14 +368,12 @@ def main():
     train_dataset = ImageTextDataset(
         data_args.data_dir,
        data_args.train_file,
-        captions_per_image=2,
         transform=preprocess,
     )
 
     eval_dataset = ImageTextDataset(
         data_args.data_dir,
         data_args.validation_file,
-        captions_per_image=1,
         transform=preprocess,
     )
 
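The substantive change here is the caption pipeline: each JSON example now carries a single "caption" field holding a dict of report sections instead of a list of captions, and _load_target selects the longest string-valued section while skipping null ones. A minimal sketch of the assumed record shape (the section names below are illustrative, not taken from the diff):

    # Hypothetical MIMIC-CXR-style record in the shape the new
    # _load_target expects: "caption" maps section names to text.
    example = {
        "image_path": "files/p10/s5041/img.jpg",  # joined onto data_dir
        "caption": {
            "findings": "The lungs are clear without focal consolidation or effusion.",
            "impression": "No acute cardiopulmonary process.",
            "last_paragraph": None,  # non-string sections are filtered out
        },
    }

    sections = example["caption"]
    longest_section = max(
        filter(lambda x: isinstance(x, str), sections.values()),
        key=len,
    )
    assert longest_section.startswith("The lungs are clear")

One caveat: max over an empty iterable raises ValueError, so a record whose sections are all null would crash the loader. The wandb setup added after argument parsing is gated on jax.process_index() == 0, so only the first host logs in a multi-host TPU run, and entity=getpass.getuser() assumes each user's login name matches a wandb entity of the same name.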
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (128 Bytes)

src/__pycache__/configuration_medclip.cpython-38.pyc ADDED
Binary file (4.17 kB)

src/__pycache__/modeling_medclip.cpython-38.pyc ADDED
Binary file (12.9 kB)

src/__pycache__/run_medclip.cpython-38.pyc ADDED
Binary file (16.8 kB)
 
src/modeling_medclip.py CHANGED
@@ -18,7 +18,7 @@ from typing import Optional, Tuple
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from configuration_hybrid_clip import HybridCLIPConfig
+from src.configuration_medclip import HybridCLIPConfig
 from flax.core.frozen_dict import FrozenDict
 from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
 from transformers.modeling_flax_utils import FlaxPreTrainedModel
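Because the config import now resolves through the src package, the model code has to be used from the repository root. A minimal sketch of constructing the model under the new layout, assuming FlaxHybridCLIP keeps the from_text_vision_pretrained classmethod of the upstream Flax hybrid-CLIP example it was adapted from (checkpoint names taken from train_model.sh below):

    # Run from the repository root so the `src` package is importable.
    from src.modeling_medclip import FlaxHybridCLIP

    model = FlaxHybridCLIP.from_text_vision_pretrained(
        "allenai/scibert_scivocab_uncased",  # text encoder
        "openai/clip-vit-base-patch32",      # vision encoder
    )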
train_model.sh ADDED
@@ -0,0 +1,15 @@
+python run_medclip.py \
+    --output_dir model \
+    --text_model_name_or_path="allenai/scibert_scivocab_uncased" \
+    --vision_model_name_or_path="openai/clip-vit-base-patch32" \
+    --tokenizer_name="allenai/scibert_scivocab_uncased" \
+    --data_dir="/home/shared/data/mimic-cxr" \
+    --train_file="/home/shared/data/mimic-cxr/train_dataset.json" \
+    --validation_file="/home/shared/data/mimic-cxr/validate_dataset.json" \
+    --do_train --do_eval \
+    --num_train_epochs="40" --max_seq_length 512 \
+    --per_device_train_batch_size="64" \
+    --per_device_eval_batch_size="64" \
+    --learning_rate="5e-5" --warmup_steps="0" --weight_decay 0.1 \
+    --overwrite_output_dir \
+    --preprocessing_num_workers 32 \
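Since run_medclip.py imports src.modeling_medclip, this script presumably has to be launched from the repository root, and the data paths under /home/shared/data/mimic-cxr are machine-specific. Note also that the final line ends with a continuation backslash, so any line appended to the script would be folded into the same command.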