Spaces:
Configuration error
Configuration error
import os | |
import pickle | |
import numpy as np | |
import pandas as pd | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
from torch.utils.data import Dataset | |
from utils import random_box, random_click | |
class LIDC(Dataset):
    """Multi-rater segmentation dataset assembled from pickle files.

    Every entry of ``data_path`` whose name contains ``.pickle`` is unpickled
    and merged into a single ``{sample_name: record}`` mapping. Each record
    must provide the keys ``'image'``, ``'masks'`` (one mask per rater) and
    ``'series_uid'``. Image and mask values are asserted to lie in ``[0, 1]``.

    Args:
        data_path: Directory that contains the pickle files.
        transform: Stored on the instance; this class never applies it.
        transform_msk: Stored on the instance; this class never applies it.
        prompt: ``'click'`` samples a point prompt with ``random_click``;
            ``'box'`` samples a box prompt with ``random_box``. Any other
            value yields samples without a prompt.

    Attributes:
        names: Sample keys, in loading order.
        images: One float array per sample.
        labels: Per-sample collection of rater masks.
        series_uid: Per-sample ``'series_uid'`` value.
    """

    # Upper bound on a single ``read`` call; presumably a workaround for
    # platforms that fail on reads of 2 GiB or more -- TODO confirm.
    _MAX_READ_BYTES = 2**31 - 1

    def __init__(self, data_path, transform=None, transform_msk=None, prompt='click'):
        super().__init__()
        self.prompt = prompt
        self.transform = transform
        self.transform_msk = transform_msk

        # Per-instance storage. As class-level lists these were shared by
        # every LIDC object, so a second dataset silently contained the
        # samples of the first one as well.
        self.names = []
        self.images = []
        self.labels = []
        self.series_uid = []

        # Sorted so that sample indices (and which file wins on duplicate
        # keys) do not depend on the arbitrary order of os.listdir.
        filenames = sorted(os.fsdecode(entry) for entry in os.listdir(data_path))

        data = {}
        for filename in filenames:
            if '.pickle' not in filename:
                continue
            # os.path.join works whether or not data_path ends with a separator.
            data.update(self._load_pickle(os.path.join(data_path, filename)))

        for key, value in data.items():
            self.names.append(key)
            self.images.append(value['image'].astype(float))
            self.labels.append(value['masks'])
            self.series_uid.append(value['series_uid'])

        assert len(self.images) == len(self.labels) == len(self.series_uid)
        for img in self.images:
            assert np.max(img) <= 1 and np.min(img) >= 0, "image values must be in [0, 1]"
        for label in self.labels:
            assert np.max(label) <= 1 and np.min(label) >= 0, "mask values must be in [0, 1]"

    @classmethod
    def _load_pickle(cls, file_path):
        """Read ``file_path`` in bounded chunks and return the unpickled object."""
        buffer = bytearray()
        input_size = os.path.getsize(file_path)
        with open(file_path, 'rb') as f_in:
            for _ in range(0, input_size, cls._MAX_READ_BYTES):
                buffer += f_in.read(cls._MAX_READ_BYTES)
        # SECURITY: unpickling can execute arbitrary code; only load trusted files.
        return pickle.loads(buffer)

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict.

        Keys always present:
            ``'image'``: float32 tensor; the image gets a leading channel
                axis that is repeated three times.
            ``'multi_rater'``: float32 tensor of the stacked rater masks with
                a singleton axis inserted at position 1.
            ``'label'``: mean of ``'multi_rater'`` over the rater axis.
            ``'p_label'``: point label (``1`` unless ``random_click``
                returns a different one).
            ``'image_meta_dict'``: ``{'filename_or_obj': <sample name>}``.

        Keys present only for the matching prompt mode:
            ``'pt'``: point returned by ``random_click`` (``prompt='click'``).
            ``'box'``: ``[x_min, x_max, y_min, y_max]`` from ``random_box``
                (``prompt='box'``).

        The prompt keys are added conditionally because referencing them
        unconditionally raised ``NameError`` for every prompt mode.
        """
        point_label = 1
        img = np.expand_dims(self.images[index], axis=0)
        name = self.names[index]
        multi_rater = self.labels[index]

        # Holds whichever prompt entries apply to the configured mode.
        prompt_fields = {}

        # First click is the target with most agreement among raters,
        # otherwise background agreement.
        if self.prompt == 'click':
            # NOTE(review): masks are asserted to be in [0, 1] at load time,
            # so the division by 255 looks suspicious -- confirm against
            # random_click before changing it.
            mean_mask = np.array(np.mean(np.stack(multi_rater), axis=0)) / 255
            point_label, pt = random_click(mean_mask, point_label)
            prompt_fields['pt'] = pt

        # Convert image (ensure three channels) and rater masks to tensors.
        img = torch.from_numpy(img).type(torch.float32)
        img = img.repeat(3, 1, 1)
        multi_rater = torch.stack(
            [torch.from_numpy(single_rater).type(torch.float32) for single_rater in multi_rater],
            dim=0,
        )
        multi_rater = multi_rater.unsqueeze(1)

        if self.prompt == 'box':
            x_min, x_max, y_min, y_max = random_box(multi_rater)
            prompt_fields['box'] = [x_min, x_max, y_min, y_max]

        mask = multi_rater.mean(dim=0)  # average over raters

        return {
            'image': img,
            'multi_rater': multi_rater,
            'label': mask,
            'p_label': point_label,
            **prompt_fields,
            'image_meta_dict': {'filename_or_obj': name},
        }