File size: 3,058 Bytes
3f31c34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import pickle

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset

from utils import random_box, random_click


class LIDC(Dataset):
    """LIDC multi-rater dataset held fully in memory, loaded from pickle files.

    Every file in ``data_path`` whose name contains ``'.pickle'`` is unpickled;
    each is expected to hold a dict mapping a sample name to a dict with the
    keys ``'image'``, ``'masks'`` and ``'series_uid'``. The dicts are merged
    (files are processed in sorted filename order; a later file wins on a
    duplicate sample name).

    Args:
        data_path: Directory containing the pickle files. A trailing path
            separator is optional.
        transform: Stored on the instance; not applied anywhere in this class.
        transform_msk: Stored on the instance; not applied anywhere in this
            class.
        prompt: ``'click'`` samples a point prompt via ``random_click``,
            ``'box'`` samples a box prompt via ``random_box``; any other value
            samples neither.

    Security: ``pickle.loads`` can execute arbitrary code embedded in a file.
    Only point ``data_path`` at trusted data.
    """

    # Files are read in chunks of at most this many bytes (2 GiB - 1).
    # NOTE(review): presumably a workaround for platforms where a single
    # read() of >= 2 GiB fails — confirm before removing.
    _MAX_READ_BYTES = 2**31 - 1

    def __init__(self, data_path, transform=None, transform_msk = None, prompt = 'click'):
        self.prompt = prompt
        self.transform = transform
        self.transform_msk = transform_msk

        # Per-instance storage. These were previously class attributes, so
        # every instance appended into the same shared lists.
        self.names = []
        self.images = []
        self.labels = []
        self.series_uid = []

        data = self._load_pickles(data_path)
        for key, value in data.items():
            self.names.append(key)
            self.images.append(value['image'].astype(float))
            self.labels.append(value['masks'])
            self.series_uid.append(value['series_uid'])
        # astype() made copies of the images; drop the raw dict so the
        # original arrays can be freed before validation.
        del data

        assert len(self.images) == len(self.labels) == len(self.series_uid)

        for img in self.images:
            assert np.max(img) <= 1 and np.min(img) >= 0, 'image values must lie in [0, 1]'
        for label in self.labels:
            assert np.max(label) <= 1 and np.min(label) >= 0, 'mask values must lie in [0, 1]'

    @classmethod
    def _load_pickles(cls, data_path):
        """Return the merged dict of every pickle file found in ``data_path``."""
        data = {}
        # Sorted so sample order and duplicate-name resolution do not depend
        # on the arbitrary order returned by os.listdir.
        for file in sorted(os.listdir(data_path)):
            filename = os.fsdecode(file)
            if '.pickle' in filename:
                data.update(cls._read_pickle(os.path.join(data_path, filename)))
        return data

    @classmethod
    def _read_pickle(cls, file_path):
        """Read ``file_path`` in bounded-size chunks and unpickle its contents."""
        bytes_in = bytearray(0)
        input_size = os.path.getsize(file_path)
        with open(file_path, 'rb') as f_in:
            for _ in range(0, input_size, cls._MAX_READ_BYTES):
                bytes_in += f_in.read(cls._MAX_READ_BYTES)
        # SECURITY: unpickling executes code from the file; trusted data only.
        return pickle.loads(bytes_in)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return sample ``index`` as a dict.

        Always-present keys:
            'image': float32 tensor; the stored image with a leading channel
                axis added and repeated 3 times (3 x H x W for a 2-D image).
            'multi_rater': float32 tensor of the rater masks stacked on dim 0
                with a singleton dim inserted at dim 1 (R x 1 x H x W for
                2-D masks).
            'label': mean of 'multi_rater' over the rater dimension.
            'p_label': point label (1 unless changed by ``random_click``).
            'image_meta_dict': ``{'filename_or_obj': <sample name>}``.

        Prompt keys, present only for the matching ``prompt``:
            'pt': point returned by ``random_click`` (``prompt == 'click'``).
            'box': ``[x_min, x_max, y_min, y_max]`` from ``random_box``
                (``prompt == 'box'``).
        """
        point_label = 1

        # Get the image and its per-rater masks.
        img = np.expand_dims(self.images[index], axis=0)
        name = self.names[index]
        multi_rater = self.labels[index]

        # First click targets the region of most agreement among raters,
        # otherwise background agreement (as delegated to random_click).
        # NOTE(review): masks are asserted to lie in [0, 1] in __init__, so the
        # / 255 looks like a leftover from 8-bit masks — confirm random_click
        # does not depend on the absolute scale.
        if self.prompt == 'click':
            point_label, pt = random_click(np.array(np.mean(np.stack(multi_rater), axis=0)) / 255, point_label)

        # Convert the image (replicated to three channels) and the multi-rater
        # labels to float32 tensors.
        img = torch.from_numpy(img).type(torch.float32)
        img = img.repeat(3, 1, 1)
        multi_rater = torch.stack(
            [torch.from_numpy(single_rater).type(torch.float32) for single_rater in multi_rater],
            dim=0,
        )
        multi_rater = multi_rater.unsqueeze(1)

        if self.prompt == 'box':
            x_min, x_max, y_min, y_max = random_box(multi_rater)
            box = [x_min, x_max, y_min, y_max]

        mask = multi_rater.mean(dim=0)  # average over raters

        sample = {
            'image': img,
            'multi_rater': multi_rater,
            'label': mask,
            'p_label': point_label,
        }
        # Only emit prompt entries that were actually computed; previously the
        # uncomputed one raised UnboundLocalError on every call.
        if self.prompt == 'click':
            sample['pt'] = pt
        if self.prompt == 'box':
            sample['box'] = box
        sample['image_meta_dict'] = {'filename_or_obj': name}
        return sample