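# Provenance (Hugging Face file-viewer header, kept as a comment so the module parses):
#   uploaded by introvoyz041 using huggingface_hub ("Upload folder using huggingface_hub")
#   commit 3f31c34 (verified), file size 3.06 kB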
import os
import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from utils import random_box, random_click
class LIDC(Dataset):
    """Multi-rater segmentation dataset backed by ``*.pickle`` files.

    Every file in ``data_path`` whose name contains ``'.pickle'`` is
    unpickled and expected to yield a dict mapping a sample name to a record
    with the keys ``'image'``, ``'masks'`` and ``'series_uid'``.  All files
    are merged into one in-memory dataset.

    Args:
        data_path: Directory holding the pickle files (``str`` or path-like).
        transform: Stored on the instance; not applied by ``__getitem__``.
        transform_msk: Stored on the instance; not applied by ``__getitem__``.
        prompt: ``'click'`` to sample a point prompt with ``random_click``,
            ``'box'`` to sample a box prompt with ``random_box``.  Any other
            value produces samples without a prompt.

    Raises:
        AssertionError: If the loaded lists differ in length or an image or
            mask has values outside ``[0, 1]``.
    """

    def __init__(self, data_path, transform=None, transform_msk=None, prompt='click'):
        self.prompt = prompt
        self.transform = transform
        self.transform_msk = transform_msk

        # Storage must be per instance: class-level lists would be shared, so
        # a second dataset (e.g. a train/test pair) would also contain every
        # sample loaded by the first one.
        self.names = []
        self.images = []
        self.labels = []
        self.series_uid = []

        data = self._load_pickles(data_path)
        for key, value in data.items():
            self.names.append(key)
            self.images.append(value['image'].astype(float))
            self.labels.append(value['masks'])
            self.series_uid.append(value['series_uid'])
        # Drop the merged dict early; the per-field lists now hold the data.
        del data

        assert len(self.images) == len(self.labels) == len(self.series_uid), \
            'images, labels and series_uid must have the same length'
        for img in self.images:
            assert np.max(img) <= 1 and np.min(img) >= 0, \
                'image values must lie in [0, 1]'
        for label in self.labels:
            assert np.max(label) <= 1 and np.min(label) >= 0, \
                'mask values must lie in [0, 1]'

    @staticmethod
    def _load_pickles(data_path):
        """Return the merged contents of every pickle file in *data_path*."""
        # Each read is capped at 2**31 - 1 bytes; presumably a workaround for
        # platforms that fail on a single read larger than 2 GiB.
        max_bytes = 2**31 - 1
        root = os.fsdecode(data_path)
        merged = {}
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # sample order (and the winner among duplicate keys) reproducible.
        for entry in sorted(os.listdir(root)):
            filename = os.fsdecode(entry)
            if '.pickle' not in filename:
                continue
            # os.path.join works whether or not data_path ends in a separator.
            file_path = os.path.join(root, filename)
            input_size = os.path.getsize(file_path)
            bytes_in = bytearray()
            with open(file_path, 'rb') as f_in:
                for _ in range(0, input_size, max_bytes):
                    bytes_in += f_in.read(max_bytes)
            # SECURITY: unpickling can execute arbitrary code. Only point this
            # loader at dataset files from a trusted source.
            merged.update(pickle.loads(bytes_in))
        return merged

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.images)

    def __getitem__(self, index):
        """Return sample *index* as a dict of tensors and metadata.

        Keys always present: ``'image'`` (image repeated to three channels),
        ``'multi_rater'`` (one mask per rater with a singleton channel
        axis), ``'label'`` (mean over raters), ``'p_label'`` and
        ``'image_meta_dict'``.  ``'pt'`` is present only when
        ``prompt == 'click'`` and ``'box'`` only when ``prompt == 'box'``.
        """
        point_label = 1

        # Fetch the raw image, its name and the per-rater masks.
        img = np.expand_dims(self.images[index], axis=0)
        name = self.names[index]
        multi_rater = self.labels[index]

        use_click = self.prompt == 'click'
        use_box = self.prompt == 'box'

        # The click is sampled from the rater-averaged mask.
        # NOTE(review): masks are asserted to lie in [0, 1] at load time, so
        # the extra "/ 255" looks suspicious -- confirm what random_click
        # expects before changing it.
        if use_click:
            point_label, pt = random_click(
                np.array(np.mean(np.stack(multi_rater), axis=0)) / 255,
                point_label,
            )

        # Convert image (ensure three channels) and multi-rater labels to
        # torch tensors.
        img = torch.from_numpy(img).type(torch.float32)
        img = img.repeat(3, 1, 1)
        multi_rater = torch.stack(
            [torch.from_numpy(single_rater).type(torch.float32)
             for single_rater in multi_rater],
            dim=0,
        )
        multi_rater = multi_rater.unsqueeze(1)

        if use_box:
            x_min, x_max, y_min, y_max = random_box(multi_rater)
            box = [x_min, x_max, y_min, y_max]

        mask = multi_rater.mean(dim=0)  # average over raters

        sample = {
            'image': img,
            'multi_rater': multi_rater,
            'label': mask,
            'p_label': point_label,
        }
        # Only emit the prompt that was actually computed. Referencing the
        # other one raised UnboundLocalError, and a None placeholder would
        # not survive the DataLoader's default collate function.
        if use_click:
            sample['pt'] = pt
        if use_box:
            sample['box'] = box
        sample['image_meta_dict'] = {'filename_or_obj': name}
        return sample