# medclip/prepare_data.py
# feat: add data preprocessing pipeline (commit 09ad451, author: santiago)
#!/usr/bin/env python
# coding: utf-8
# BUG FIX: `List1` is not a name in `typing` and raised ImportError on every
# run; the return annotation of dump_dataset needs `List`.
# Imports grouped: stdlib first, then third-party.
from typing import Dict, List
import argparse
import json
import pathlib
import re
import shutil
from functools import partial

import pandas as pd
from PIL import Image
from tqdm import tqdm

# Maps an absolute image path to its caption fields
# (section name -> cleaned section text).
ImageCaptionMap = Dict[str, Dict[str, str]]
def _get_image_path(row: pd.Series, root_dir: str = '.') -> str:
path = [
root_dir,
'files',
f'p{row.subject_id}'[:3],
f'p{row.subject_id}',
f's{row.study_id}',
f'{row.dicom_id}.jpg'
]
return '/'.join(path)
def _prepare_dataframe(
captions: pd.DataFrame,
metadata: pd.DataFrame,
row: pd.Series
) -> pd.Series:
if f's{row.study_id}' in captions.index:
row[captions.columns] = (
captions
.loc[f's{row.study_id}']
.apply(lambda text: (
re.sub('_+', '_', text)
.replace('\n', ' ')
.lower().rstrip('.')
))
)
if row.dicom_id in metadata.index:
row['view_position'] = metadata.loc[row.dicom_id, 'ViewPosition']
return row
def copy_image(
    row: pd.Series,
    target_path: pathlib.Path,
    split: str,
    size: int = 224
) -> str:
    """Resize the image at ``row.path`` to ``(size, size)`` and save it under
    ``target_path/split/<dicom_id>.jpg``.

    Returns the absolute path of the saved copy.
    """
    target_img_path = str((target_path / split / f'{row.dicom_id}.jpg').resolve())
    # BUG FIX: the original left the source file handle open (Image.open is
    # lazy and never got closed); the context manager releases it even if
    # resize/save raises — important when iterating thousands of files.
    with Image.open(row.path) as img:
        img.resize((size, size)).save(target_img_path)
    return target_img_path
def generate_dataset(
    root_dir: pathlib.Path,
    target_dir: pathlib.Path,
    split: str = 'validate'
) -> ImageCaptionMap:
    """Build the image-to-caption mapping for one MIMIC-CXR split.

    Reads the metadata, split and sectioned-report CSVs from
    ``root_dir/metadata``, resizes and copies every image of ``split`` into
    ``target_dir/split``, and returns a dict mapping each copied image path
    to its caption fields.

    Raises:
        ValueError: if ``split`` does not appear in the split CSV.
    """
    meta_dir = root_dir / 'metadata'
    metadata = pd.read_csv(meta_dir / 'mimic-cxr-2.0.0-metadata.csv')
    df_split = pd.read_csv(meta_dir / 'mimic-cxr-2.0.0-split.csv')
    captions = pd.read_csv(meta_dir / 'mimic_cxr_sectioned.csv')
    # fillna('') is the idiomatic equivalent of where(~isna(), '').
    captions = captions.fillna('').set_index('study')
    metadata = metadata.set_index('dicom_id')

    # Guard clause up front keeps the happy path flat.
    if split not in df_split.split.unique():
        raise ValueError('bad split')

    # BUG FIX: .copy() makes an independent frame.  Assigning new columns to
    # a raw boolean-filtered view triggers pandas' SettingWithCopyWarning and
    # can silently fail to write the columns.
    current_split = df_split[df_split.split == split].copy()
    get_abs_path = partial(_get_image_path, root_dir=str(root_dir.resolve()))
    current_split['path'] = current_split.apply(get_abs_path, axis=1)
    # Pre-create the columns _prepare_dataframe fills in.
    current_split['view_position'] = ''
    for col in captions.columns:
        current_split[col] = ''
    preprocess_func = partial(_prepare_dataframe, captions, metadata)
    df = current_split.apply(preprocess_func, axis=1)

    image_path_to_caption = {}
    (target_dir / split).mkdir(exist_ok=True, parents=True)
    for _, element in tqdm(df.iterrows()):
        caption = {
            'impression': element['impression'],
            'findings': element['findings'],
            'last_paragraph': element['last_paragraph'],
            'comparison': element['comparison'],
            'view_position': element['view_position'],
        }
        image_path = copy_image(element, target_dir, split)
        image_path_to_caption[image_path] = caption
    return image_path_to_caption
def dump_dataset(image_path_to_caption: ImageCaptionMap) -> List[str]:
    """Serialize the mapping to JSON-Lines strings, one record per image."""
    return [
        json.dumps({'image_path': image_path, 'caption': captions})
        for image_path, captions in image_path_to_caption.items()
    ]
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Preprocess MIMIC-CXR dataset')
    # BUG FIX: add_argument has no 'description' keyword — the original raised
    # TypeError before parsing anything; 'help' is the correct parameter.
    # required=True gives a usage error instead of Path(None) crashing below.
    parser.add_argument('--data_dir', required=True, help='MIMIC-CXR path')
    parser.add_argument('--target_dir', required=True, help='output path')
    args = parser.parse_args()

    data_dir = pathlib.Path(args.data_dir)
    target_dir = pathlib.Path(args.target_dir)
    for split in ['test', 'validate', 'train']:
        image_path_to_caption = generate_dataset(data_dir, target_dir, split)
        lines = dump_dataset(image_path_to_caption)
        # One JSON object per line (JSON-Lines), despite the .json suffix.
        with open(target_dir / f'{split}_dataset.json', 'w') as f:
            f.write('\n'.join(lines))