#!/usr/bin/env python
# coding: utf-8
"""Preprocess the MIMIC-CXR dataset into resized images plus JSON caption files.

For each split (test/validate/train) this script:
  1. joins the split table with the sectioned report captions and DICOM metadata,
  2. copies every image into ``<target_dir>/<split>/`` resized to 224x224,
  3. writes ``<split>_dataset.json`` mapping each image path to its captions.
"""
from typing import Dict, List  # fixed: was `List1`, which broke the import

import argparse
import json
import pathlib
import re
from functools import partial

import pandas as pd
from PIL import Image
from tqdm import tqdm

# Maps an absolute image path -> {caption section name: caption text}.
ImageCaptionMap = Dict[str, Dict[str, str]]


def _get_image_path(row: pd.Series, root_dir: str = '.') -> str:
    """Build the source JPEG path for one split-table row.

    MIMIC-CXR stores files under ``files/p1X/p<subject_id>/s<study_id>/``,
    where ``p1X`` is the first three characters of ``p<subject_id>``.

    Args:
        row: a split-table row with ``subject_id``, ``study_id``, ``dicom_id``.
        root_dir: dataset root directory.

    Returns:
        The '/'-joined path to the DICOM's JPEG file.
    """
    path = [
        root_dir,
        'files',
        f'p{row.subject_id}'[:3],  # top-level shard dir, e.g. 'p10'
        f'p{row.subject_id}',
        f's{row.study_id}',
        f'{row.dicom_id}.jpg',
    ]
    return '/'.join(path)


def _prepare_dataframe(
    captions: pd.DataFrame,
    metadata: pd.DataFrame,
    row: pd.Series
) -> pd.Series:
    """Attach normalized caption text and the view position to one row.

    Caption text is lower-cased, newline-flattened, trailing-dot-stripped,
    and runs of underscores (de-identification blanks) are collapsed.
    Rows whose study or dicom id is missing from the lookup tables keep
    their previously-initialized empty-string values.

    Args:
        captions: sectioned reports indexed by 's<study_id>'.
        metadata: DICOM metadata indexed by dicom_id.
        row: one split-table row; mutated and returned.

    Returns:
        The row with caption columns and ``view_position`` filled in.
    """
    study_key = f's{row.study_id}'
    if study_key in captions.index:
        row[captions.columns] = (
            captions
            .loc[study_key]
            .apply(lambda text: (
                re.sub('_+', '_', text)
                .replace('\n', ' ')
                .lower().rstrip('.')
            ))
        )
    if row.dicom_id in metadata.index:
        row['view_position'] = metadata.loc[row.dicom_id, 'ViewPosition']
    return row


def copy_image(
    row: pd.Series,
    target_path: pathlib.Path,
    split: str,
    size: int = 224
) -> str:
    """Resize one image to ``size``x``size`` and save it under the split dir.

    Args:
        row: a row providing ``dicom_id`` and source ``path``.
        target_path: output root; the file goes to ``target_path/split/``.
        split: split name used as the subdirectory.
        size: output edge length in pixels.

    Returns:
        The absolute path of the written image.
    """
    target_img_path = target_path / split / f'{row.dicom_id}.jpg'
    target_img_path = str(target_img_path.resolve())
    img = Image.open(row.path)
    img = img.resize((size, size))
    img.save(target_img_path)
    return target_img_path


def generate_dataset(
    root_dir: pathlib.Path,
    target_dir: pathlib.Path,
    split: str = 'validate'
) -> ImageCaptionMap:
    """Build the image->caption map for one split, copying images as a side effect.

    Args:
        root_dir: MIMIC-CXR root containing ``metadata/`` and ``files/``.
        target_dir: output root; ``target_dir/<split>/`` is created if needed.
        split: one of the values in the split table ('train'/'validate'/'test').

    Returns:
        Mapping from each copied image's absolute path to its caption dict.

    Raises:
        ValueError: if ``split`` is not present in the split table.
    """
    meta_dir = root_dir / 'metadata'
    metadata = pd.read_csv(meta_dir / 'mimic-cxr-2.0.0-metadata.csv')
    df_split = pd.read_csv(meta_dir / 'mimic-cxr-2.0.0-split.csv')
    captions = pd.read_csv(meta_dir / 'mimic_cxr_sectioned.csv')
    # Replace NaN sections with '' so the per-row text cleanup always sees str.
    captions = captions.where(~captions.isna(), '').set_index('study')
    metadata = metadata.set_index('dicom_id')

    if split not in df_split.split.unique():
        raise ValueError('bad split')

    # .copy() so the column assignments below write to a real frame, not a
    # view of df_split (avoids pandas chained-assignment pitfalls).
    current_split = df_split[df_split.split == split].copy()
    get_abs_path = partial(_get_image_path, root_dir=str(root_dir.resolve()))
    current_split['path'] = current_split.apply(get_abs_path, axis=1)
    current_split['view_position'] = ''
    for col in captions.columns:
        current_split[col] = ''
    # Fixed: the original split `preprocess_func = partial(...)` across a
    # broken line continuation, which was a SyntaxError.
    preprocess_func = partial(_prepare_dataframe, captions, metadata)
    df = current_split.apply(preprocess_func, axis=1)

    image_path_to_caption = {}
    (target_dir / split).mkdir(exist_ok=True, parents=True)
    for _, element in tqdm(df.iterrows()):
        caption = {
            'impression': element['impression'],
            'findings': element['findings'],
            'last_paragraph': element['last_paragraph'],
            'comparison': element['comparison'],
            'view_position': element['view_position'],
        }
        image_path = copy_image(element, target_dir, split)
        image_path_to_caption[image_path] = caption
    return image_path_to_caption


def dump_dataset(image_path_to_caption: ImageCaptionMap) -> List[str]:
    """Serialize the image->caption map to one JSON object per line.

    Args:
        image_path_to_caption: mapping produced by :func:`generate_dataset`.

    Returns:
        A list of JSON strings, one per image.
    """
    return [
        json.dumps({'image_path': image_path, 'caption': captions})
        for image_path, captions in image_path_to_caption.items()
    ]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Preprocess MIMIC-CXR dataset')
    # Fixed: add_argument takes `help=`, not `description=` (TypeError before).
    parser.add_argument('--data_dir', help='MIMIC-CXR path')
    parser.add_argument('--target_dir', help='output path')
    args = parser.parse_args()
    data_dir = pathlib.Path(args.data_dir)
    target_dir = pathlib.Path(args.target_dir)
    for split in ['test', 'validate', 'train']:
        image_path_to_caption = generate_dataset(data_dir, target_dir, split)
        lines = dump_dataset(image_path_to_caption)
        with open(target_dir / f'{split}_dataset.json', 'w') as f:
            f.write('\n'.join(lines))