![image/jpeg](https://cdn-uploads.huggingface.co/production/uploads/6565b54a9bf6665f10f75441/no60wyvKDTD-WV3pCt2P5.jpeg)

SegVol is a universal and interactive model for volumetric medical image segmentation. It accepts point, box, and text prompts and outputs volumetric segmentation masks. Trained on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CT volumes, this foundation model supports the segmentation of over 200 anatomical categories.

The [**Paper**](https://arxiv.org/abs/2311.13385), [**Code**](https://github.com/BAAI-DCAI/SegVol), and [**Demo**](https://huggingface.co/spaces/BAAI/SegVol) have been released.

**Keywords**: 3D medical SAM, volumetric image segmentation

## Quickstart

### Requirements
```bash
conda create -n segvol_transformers python=3.8
conda activate segvol_transformers
```
[pytorch v1.11.0](https://pytorch.org/get-started/previous-versions/) or a higher version is required. Please also install the following support packages:

```bash
pip install 'monai[all]==0.9.0'
pip install einops==0.6.1
pip install transformers==4.18.0
pip install matplotlib
```
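
As a quick sanity check (a minimal sketch of our own, not part of the official setup), you can confirm the installed versions and GPU availability before running the scripts below:

```python
import torch
import monai
import transformers
import einops

# print the installed package versions and check that a CUDA device is visible
print(torch.__version__, monai.__version__, transformers.__version__, einops.__version__)
print('CUDA available:', torch.cuda.is_available())
```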

### Test script

```python
from transformers import AutoModel, AutoTokenizer
import torch
import os

# get device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load model
clip_tokenizer = AutoTokenizer.from_pretrained("BAAI/SegVol")
model = AutoModel.from_pretrained("BAAI/SegVol", trust_remote_code=True, test_mode=True)
model.model.text_encoder.tokenizer = clip_tokenizer
model.eval()
model.to(device)
print('model load done')

# set case path
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'

# set categories, corresponding to the unique values (1, 2, 3, 4, ...) in the ground truth mask
categories = ["liver", "kidney", "spleen", "pancreas"]

# generate npy data format
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
# If you have downloaded our 25 processed datasets, you can skip to here with the processed ct_npy, gt_npy files

# go through zoom_transform to generate zoom-out & zoom-in views
data_item = model.processor.zoom_transform(ct_npy, gt_npy)

# add batch dim manually
data_item['image'], data_item['label'], data_item['zoom_out_image'], data_item['zoom_out_label'] = \
    data_item['image'].unsqueeze(0).to(device), data_item['label'].unsqueeze(0).to(device), data_item['zoom_out_image'].unsqueeze(0).to(device), data_item['zoom_out_label'].unsqueeze(0).to(device)

# take the liver as an example
cls_idx = 0

# text prompt
text_prompt = [categories[cls_idx]]

# point prompt
point_prompt, point_prompt_map = model.processor.point_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)   # inputs w/o batch dim, outputs w/ batch dim

# bbox prompt
bbox_prompt, bbox_prompt_map = model.processor.bbox_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)   # inputs w/o batch dim, outputs w/ batch dim

print('prompt done')

# segvol test forward
# use_zoom: use zoom-out-zoom-in
# point_prompt_group: use point prompt
# bbox_prompt_group: use bbox prompt
# text_prompt: use text prompt
logits_mask = model.forward_test(image=data_item['image'],
                                 zoomed_image=data_item['zoom_out_image'],
                                 # point_prompt_group=[point_prompt, point_prompt_map],
                                 bbox_prompt_group=[bbox_prompt, bbox_prompt_map],
                                 text_prompt=text_prompt,
                                 use_zoom=True
                                 )

# calculate dice score
dice = model.processor.dice_score(logits_mask[0][0], data_item['label'][0][cls_idx], device)
print(dice)

# save prediction as a nii.gz file
save_path = './Case_preds_00001.nii.gz'
model.processor.save_preds(ct_path, save_path, logits_mask[0][0],
                           start_coord=data_item['foreground_start_coord'],
                           end_coord=data_item['foreground_end_coord'])
print('done')
```
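
The commented-out `point_prompt_group` line above shows how to switch between spatial prompts. For example, to drive the same call with the point prompt instead of the bbox prompt (keeping the text prompt alongside it, as in the original example):

```python
# same call as above, driven by the point prompt instead of the bbox prompt
logits_mask_point = model.forward_test(image=data_item['image'],
                                       zoomed_image=data_item['zoom_out_image'],
                                       point_prompt_group=[point_prompt, point_prompt_map],
                                       text_prompt=text_prompt,
                                       use_zoom=True
                                       )
```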

### Training script

```python
from transformers import AutoModel, AutoTokenizer
import torch
import os

# get device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load model
clip_tokenizer = AutoTokenizer.from_pretrained("BAAI/SegVol")
model = AutoModel.from_pretrained("BAAI/SegVol", trust_remote_code=True, test_mode=False)
model.model.text_encoder.tokenizer = clip_tokenizer
model.train()
model.to(device)
print('model load done')

# set case path
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'

# set categories, corresponding to the unique values (1, 2, 3, 4, ...) in the ground truth mask
categories = ["liver", "kidney", "spleen", "pancreas"]

# generate npy data format
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
# If you have downloaded our 25 processed datasets, you can skip to here with the processed ct_npy, gt_npy files

# go through train transform
data_item = model.processor.train_transform(ct_npy, gt_npy)

# training example
# add batch dim manually
image, gt3D = data_item["image"].unsqueeze(0).to(device), data_item["label"].unsqueeze(0).to(device)

loss_step_avg = 0
for cls_idx in range(len(categories)):
    # optimizer.zero_grad()
    organs_cls = categories[cls_idx]
    labels_cls = gt3D[:, cls_idx]
    loss = model.forward_train(image, train_organs=organs_cls, train_labels=labels_cls)
    loss_step_avg += loss.item()
    loss.backward()
    # optimizer.step()

loss_step_avg /= len(categories)
print(f'AVG loss {loss_step_avg}')

# save ckpt
model.save_pretrained('./ckpt')
```
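
The `optimizer.zero_grad()` / `optimizer.step()` calls are left commented out above because the snippet only demonstrates a single training step on one case. For real training you would create the optimizer yourself and uncomment them; a minimal sketch (AdamW and its hyperparameters are our assumptions, not values from the SegVol paper):

```python
import torch

# hypothetical optimizer setup; plug zero_grad()/step() back into the loop above
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
```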

### Start with the M3D-Seg dataset

We have released 25 open-source datasets (M3D-Seg) for training SegVol, and the preprocessed data have been uploaded to [ModelScope](https://www.modelscope.cn/datasets/GoodBaiBai88/M3D-Seg/summary) and [HuggingFace](https://huggingface.co/datasets/GoodBaiBai88/M3D-Seg).
You can use the following script to easily load cases and insert them into the Test script and Training script above.

```python
import json, os
M3D_Seg_path = 'path/to/M3D-Seg'

# select a dataset
dataset_code = '0000'

# load json dict
json_path = os.path.join(M3D_Seg_path, dataset_code, dataset_code + '.json')
with open(json_path, 'r') as f:
    dataset_dict = json.load(f)

# get a case
ct_path = os.path.join(M3D_Seg_path, dataset_dict['train'][0]['image'])
gt_path = os.path.join(M3D_Seg_path, dataset_dict['train'][0]['label'])

# get categories
categories_dict = dataset_dict['labels']
categories = [x for _, x in categories_dict.items() if x != "background"]

# load npy data format
ct_npy, gt_npy = model.processor.load_uniseg_case(ct_path, gt_path)
```
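
The loaded `ct_npy, gt_npy` take the place of the `preprocess_ct_gt` output, so you can continue exactly as in the scripts above:

```python
# inference: build the zoom-out / zoom-in views as in the Test script
data_item = model.processor.zoom_transform(ct_npy, gt_npy)

# training: apply the training transform instead, as in the Training script
# data_item = model.processor.train_transform(ct_npy, gt_npy)
```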