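# NOTE: the three lines below appear to be web-page header residue from the
# file's hosting page; they are kept as comments so the module stays importable.
#   introvoyz041's picture
#   Upload folder using huggingface_hub
#   3f31c34 verified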
import json
import os
import pickle
import nibabel as nib
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from utils import generate_click_prompt, random_box, random_click
class Atlas(Dataset):
    """Dataset of NIfTI image/label volume pairs listed in ``dataset.json``.

    The index file is read from ``<data_path>/train/dataset.json``; its
    ``'training'`` entry is a list of records with ``'image'`` and ``'label'``
    keys holding file paths relative to ``<data_path>/train``.
    """

    def __init__(self, args, data_path, transform=None, transform_msk=None,
                 mode='Training', prompt='click', plane=False):
        """Read the index file and remember the configuration.

        Args:
            args: Configuration object; ``args.image_size`` is read here.
            data_path: Root directory that contains the ``train`` sub-folder.
            transform: Stored on the instance; not applied by ``__getitem__``.
            transform_msk: Stored on the instance; not applied by
                ``__getitem__``.
            mode: Stored on the instance; not used by the methods of this
                class.
            prompt: Prompt type. Only ``'click'`` is handled by
                ``__getitem__``.
            plane: Accepted for interface compatibility; unused.
        """
        self.args = args
        self.data_path = os.path.join(data_path, 'train')
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        index_path = os.path.join(self.data_path, 'dataset.json')
        with open(index_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        self.name_list = data['training']
        self.mode = mode
        self.prompt = prompt
        self.img_size = args.image_size
        self.transform = transform
        self.transform_msk = transform_msk

    def __len__(self):
        """Return the number of records listed under ``'training'``."""
        return len(self.name_list)

    def __getitem__(self, index):
        """Load one image/mask pair and draw a click prompt from the mask.

        Args:
            index: Position of the record in ``self.name_list``.

        Returns:
            A dict with keys ``'image'`` and ``'label'`` (tensors with the
            volume's first axis moved last and a leading singleton dimension
            added), ``'p_label'`` and ``'pt'`` (the two values returned by
            ``random_click``), and ``'image_meta_dict'`` (a dict holding the
            image path under ``'filename_or_obj'``).

        Raises:
            ValueError: If ``self.prompt`` is not ``'click'``; no other prompt
                type produces the ``'pt'`` value this method returns.
        """
        point_label = 1
        label = 1

        entry = self.name_list[index]
        img_name = entry['image']
        mask_name = entry['label']

        # Fail with a clear message before the expensive volume loads; the
        # click branch below is the only place that defines `pt`, so any other
        # prompt value previously ended in an UnboundLocalError at the return.
        if self.prompt != 'click':
            raise ValueError(
                f"Atlas only supports prompt='click'; got prompt={self.prompt!r}"
            )

        img = nib.load(os.path.join(self.data_path, img_name)).get_fdata()
        mask = nib.load(os.path.join(self.data_path, mask_name)).get_fdata()

        # Binarise the mask: voxels equal to `label` become 1, all others 0.
        mask[mask != label] = 0
        mask[mask == label] = 1

        # Move the first axis to the end (requires 3-D volumes).
        img = np.transpose(img, (1, 2, 0))
        mask = np.transpose(mask, (1, 2, 0))

        # Add a leading singleton dimension.
        img = torch.tensor(img).unsqueeze(0)
        mask = torch.tensor(mask).unsqueeze(0)

        # NOTE(review): random_click is defined in `utils`; presumably it picks
        # a point inside the mask and may adjust the point label - confirm there.
        point_label, pt = random_click(np.array(mask), point_label)

        # NOTE: self.transform / self.transform_msk are intentionally not
        # applied here; volumes are returned at their stored resolution.
        image_meta_dict = {'filename_or_obj': img_name}
        return {
            'image': img,
            'label': mask,
            'p_label': point_label,
            'pt': pt,
            'image_meta_dict': image_meta_dict,
        }