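"""Preprocess the MIMIC-CXR dataset into an image-captioning format.

For each split ('test', 'validate', 'train') the script resizes every
chest X-ray JPEG (224x224 by default), copies it into <target_dir>/<split>/,
and writes a <split>_dataset.json file with one JSON object per line that
maps the copied image path to its report sections (impression, findings,
last_paragraph, comparison) and view position.

The metadata CSVs (mimic-cxr-2.0.0-metadata.csv, mimic-cxr-2.0.0-split.csv,
mimic_cxr_sectioned.csv) are expected under <data_dir>/metadata/ and the
images under <data_dir>/files/.

Example invocation (the script filename is illustrative):

    python preprocess_mimic_cxr.py \
        --data_dir /path/to/mimic-cxr \
        --target_dir /path/to/output
"""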
from typing import Dict, List

import argparse
import json
from functools import partial
import pathlib
import shutil
import re

from tqdm import tqdm
from PIL import Image
import pandas as pd

ImageCaptionMap = Dict[str, Dict[str, str]]

def _get_image_path(row: pd.Series, root_dir: str = '.') -> str:
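    """Build the on-disk path of one image inside the MIMIC-CXR folder tree.

    The path has the form
    <root_dir>/files/pXX/p<subject_id>/s<study_id>/<dicom_id>.jpg,
    where pXX is the first three characters of 'p<subject_id>'.
    """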
    path = [
        root_dir,
        'files',
        f'p{row.subject_id}'[:3],
        f'p{row.subject_id}',
        f's{row.study_id}',
        f'{row.dicom_id}.jpg'
    ]

    return '/'.join(path)

def _prepare_dataframe(
    captions: pd.DataFrame,
    metadata: pd.DataFrame,
    row: pd.Series
) -> pd.Series:
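    """Attach report sections and the view position to one split row.

    Intended for `DataFrame.apply` via `functools.partial(_prepare_dataframe,
    captions, metadata)`, so `row` is the last positional argument. Section
    texts are lightly normalised: repeated underscores are collapsed,
    newlines become spaces, text is lower-cased and a trailing period is
    stripped.
    """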
    if f's{row.study_id}' in captions.index:
        row[captions.columns] = (
            captions
            .loc[f's{row.study_id}']
            .apply(lambda text: (
                re.sub('_+', '_', text)
                .replace('\n', ' ')
                .lower().rstrip('.')
            ))
        )

    if row.dicom_id in metadata.index:
        row['view_position'] = metadata.loc[row.dicom_id, 'ViewPosition']

    return row

def copy_image(
    row: pd.Series,
    target_path: pathlib.Path,
    split: str,
    size: int = 224
) -> str:
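    """Resize one image to `size`x`size` and save it under <target_path>/<split>/.

    Returns the absolute path of the written JPEG, which is later used as
    the key of the image-to-caption mapping.
    """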
    target_img_path = target_path / split / f'{row.dicom_id}.jpg'
    target_img_path = str(target_img_path.resolve())

    img = Image.open(row.path)
    img = img.resize((size, size))
    img.save(target_img_path)

    return target_img_path

def generate_dataset(
    root_dir: pathlib.Path,
    target_dir: pathlib.Path,
    split: str = 'validate'
) -> ImageCaptionMap:
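    """Build the image-path-to-caption mapping for one split.

    Reads the split, metadata and sectioned-report CSVs from
    <root_dir>/metadata/, fills every row of the requested split with its
    report sections and view position, copies the resized images into
    <target_dir>/<split>/ and returns a dict that maps each copied image
    path to its caption fields.
    """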
    meta_dir = root_dir / 'metadata'

    metadata = pd.read_csv(meta_dir / 'mimic-cxr-2.0.0-metadata.csv')
    df_split = pd.read_csv(meta_dir / 'mimic-cxr-2.0.0-split.csv')
    captions = pd.read_csv(meta_dir / 'mimic_cxr_sectioned.csv')

    captions = captions.where(~captions.isna(), '').set_index('study')
    metadata = metadata.set_index('dicom_id')

    if split in df_split.split.unique():
        # Work on a copy so the column assignments below do not trigger
        # pandas' SettingWithCopyWarning.
        current_split = df_split[df_split.split == split].copy()
        get_abs_path = partial(_get_image_path, root_dir=str(root_dir.resolve()))

        current_split['path'] = current_split.apply(get_abs_path, axis=1)
        current_split['view_position'] = ''
        for col in captions.columns:
            current_split[col] = ''

        preprocess_func = partial(_prepare_dataframe, captions, metadata)

        df = current_split.apply(preprocess_func, axis=1)
    else:
        raise ValueError(
            f"Unknown split '{split}'; expected one of {sorted(df_split.split.unique())}"
        )

    image_path_to_caption = {}
    (target_dir / split).mkdir(exist_ok=True, parents=True)

    for _, element in tqdm(df.iterrows(), total=len(df)):
        caption = {
            'impression': element['impression'],
            'findings': element['findings'],
            'last_paragraph': element['last_paragraph'],
            'comparison': element['comparison'],
            'view_position': element['view_position'],
        }

        image_path = copy_image(element, target_dir, split)

        image_path_to_caption[image_path] = caption

    return image_path_to_caption

def dump_dataset(image_path_to_caption: ImageCaptionMap) -> List[str]:
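    """Serialise the mapping into JSON-lines records.

    Each returned string is one JSON object of the (illustrative) form:
    {"image_path": "/abs/path/<dicom_id>.jpg",
     "caption": {"impression": "...", "findings": "...", ...}}
    """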
    lines = []

    for image_path, captions in image_path_to_caption.items():
        lines.append(json.dumps({
            'image_path': image_path,
            'caption': captions,
        }))

    return lines

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Preprocess MIMIC-CXR dataset')
    parser.add_argument('--data_dir', required=True, help='MIMIC-CXR path')
    parser.add_argument('--target_dir', required=True, help='output path')

    args = parser.parse_args()

    data_dir = pathlib.Path(args.data_dir)
    target_dir = pathlib.Path(args.target_dir)

    for split in ['test', 'validate', 'train']:
        image_path_to_caption = generate_dataset(data_dir, target_dir, split)
        lines = dump_dataset(image_path_to_caption)

        with open(target_dir / f'{split}_dataset.json', 'w') as f:
            f.write('\n'.join(lines))