import argparse
import os
import shutil
import sys
import tempfile
import time
from collections import OrderedDict
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from einops import rearrange
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric  # needed by dice_metric below
from monai.transforms import AsDiscrete
from PIL import Image
from skimage import io
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score
from tensorboardX import SummaryWriter
# from dataset import *
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

import cfg
import models.sam.utils.transforms as samtrans
import pytorch_ssim
# from models.discriminatorlayer import discriminator
from conf import settings
from utils import *

# from lucent.modelzoo.util import get_model_layers
# from lucent.optvis import render, param, transform, objectives
# from lucent.modelzoo import inceptionv1

args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)
pos_weight = torch.ones([1]).cuda(device=GPUdevice) * 2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
seed = torch.randint(1, 11, (args.b, 7))

torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []


def train_sam(args, net: nn.Module, optimizer, train_loader,
              epoch, writer, schedulers=None, vis=50):
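    """Train `net` for one epoch with click-style point prompts.

    The prompt encoder always runs under torch.no_grad(); which image-encoder
    weights receive gradients depends on args.mod (adapter, LoRA, AdaLoRA, or
    full fine-tuning). Returns the mean batch loss over the epoch.
    """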
    hard = 0
    epoch_loss = 0
    ind = 0
    # train mode
    net.train()
    optimizer.zero_grad()

    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G
    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
        for pack in train_loader:
            # torch.cuda.empty_cache()
            imgs = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masks = pack['label'].to(dtype=torch.float32, device=GPUdevice)
            # for k, v in pack['image_meta_dict'].items():
            #     print(k)
            if 'pt' not in pack:
                imgs, pt, masks = generate_click_prompt(imgs, masks)
            else:
                pt = pack['pt']
                point_labels = pack['p_label']
            name = pack['image_meta_dict']['filename_or_obj']

            if args.thd:
                imgs, pt, masks = generate_click_prompt(imgs, masks)

                pt = rearrange(pt, 'b n d -> (b d) n')
                imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
                masks = rearrange(masks, 'b c h w d -> (b d) c h w')

                imgs = imgs.repeat(1, 3, 1, 1)
                point_labels = torch.ones(imgs.size(0))

                imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

            showp = pt

            mask_type = torch.float32
            ind += 1
            b_size, c, w, h = imgs.size()
            longsize = w if w >= h else h

            if point_labels.clone().flatten()[0] != -1:
                # point_coords = samtrans.ResizeLongestSide(longsize).apply_coords(pt, (h, w))
                point_coords = pt
                coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                if len(point_labels.shape) == 1:  # only one point prompt
                    coords_torch, labels_torch, showp = coords_torch[None, :, :], labels_torch[None, :], showp[None, :, :]
                pt = (coords_torch, labels_torch)
            '''init'''
            if hard:
                true_mask_ave = (true_mask_ave > 0.5).float()
                # true_mask_ave = cons_tensor(true_mask_ave)
            # imgs = imgs.to(dtype=mask_type, device=GPUdevice)
            '''Train'''
            if args.mod == 'sam_adpt':
                for n, value in net.image_encoder.named_parameters():
                    if "Adapter" not in n:
                        value.requires_grad = False
                    else:
                        value.requires_grad = True
            elif args.mod == 'sam_lora' or args.mod == 'sam_adalora':
                from models.common import loralib as lora
                lora.mark_only_lora_as_trainable(net.image_encoder)
                if args.mod == 'sam_adalora':
                    # Initialize the RankAllocator
                    rankallocator = lora.RankAllocator(
                        net.image_encoder, lora_r=4, target_rank=8,
                        init_warmup=500, final_warmup=1500, mask_interval=10,
                        total_step=3000, beta1=0.85, beta2=0.85,
                    )
            else:
                for n, value in net.image_encoder.named_parameters():
                    value.requires_grad = True
            imge = net.image_encoder(imgs)
            with torch.no_grad():
                if args.net == 'sam' or args.net == 'mobile_sam':
                    se, de = net.prompt_encoder(
                        points=pt,
                        boxes=None,
                        masks=None,
                    )
                elif args.net == "efficient_sam":
                    coords_torch, labels_torch = transform_prompt(coords_torch, labels_torch, h, w)
                    se = net.prompt_encoder(
                        coords=coords_torch,
                        labels=labels_torch,
                    )
            if args.net == 'sam':
                pred, _ = net.mask_decoder(
                    image_embeddings=imge,
                    image_pe=net.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=(args.multimask_output > 1),
                )
            elif args.net == 'mobile_sam':
                pred, _ = net.mask_decoder(
                    image_embeddings=imge,
                    image_pe=net.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False,
                )
            elif args.net == "efficient_sam":
                se = se.view(
                    se.shape[0],
                    1,
                    se.shape[1],
                    se.shape[2],
                )
                pred, _ = net.mask_decoder(
                    image_embeddings=imge,
                    image_pe=net.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    multimask_output=False,
                )
            # Resize to the desired output size
            pred = F.interpolate(pred, size=(args.out_size, args.out_size))
            loss = lossfunc(pred, masks)

            pbar.set_postfix(**{'loss (batch)': loss.item()})
            epoch_loss += loss.item()

            # nn.utils.clip_grad_value_(net.parameters(), 0.1)
            if args.mod == 'sam_adalora':
                (loss + lora.compute_orth_regu(net, regu_weight=0.1)).backward()
                optimizer.step()
                rankallocator.update_and_mask(net, ind)
            else:
                loss.backward()
                optimizer.step()

            optimizer.zero_grad()
            '''vis images'''
            if vis:
                if ind % vis == 0:
                    namecat = 'Train'
                    for na in name[:2]:
                        namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
                    vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False, points=showp)

            pbar.update()

    # Return the mean loss over the epoch; the running sum was previously
    # accumulated but never used, and only the last batch's loss was returned.
    return epoch_loss / len(train_loader)
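
# A minimal sketch of how train_sam is typically driven from a training
# script. `get_network` is assumed to come from `utils` (wildcard-imported
# above); `get_dataloaders` is a hypothetical stand-in for the project's
# dataset setup.
#
#     net = get_network(args, args.net, use_gpu=True, gpu_device=GPUdevice)
#     optimizer = optim.Adam(net.parameters(), lr=args.lr)
#     train_loader, val_loader = get_dataloaders(args)
#     writer = SummaryWriter(log_dir=args.path_helper['log_path'])
#     for epoch in range(settings.EPOCH):
#         epoch_loss = train_sam(args, net, optimizer, train_loader, epoch, writer)
#         writer.add_scalar('Train/loss', epoch_loss, epoch)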


def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
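    """Evaluate `net` on the validation set.

    3D inputs (args.thd) are sliced into chunks of args.evl_chunk along the
    last axis and evaluated chunk by chunk. Returns the mean loss and a tuple
    of segmentation scores from eval_seg, both averaged over batches.
    """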
    # eval mode
    net.eval()

    mask_type = torch.float32
    n_val = len(val_loader)  # the number of batches
    ave_res, mix_res = (0, 0, 0, 0), (0,) * args.multimask_output * 2
    rater_res = [(0, 0, 0, 0) for _ in range(6)]
    tot = 0
    hard = 0
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
    device = GPUdevice

    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G
    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for ind, pack in enumerate(val_loader):
            imgsw = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masksw = pack['label'].to(dtype=torch.float32, device=GPUdevice)
            # for k, v in pack['image_meta_dict'].items():
            #     print(k)
            if 'pt' not in pack or args.thd:
                imgsw, ptw, masksw = generate_click_prompt(imgsw, masksw)
            else:
                ptw = pack['pt']
                point_labels = pack['p_label']
            name = pack['image_meta_dict']['filename_or_obj']

            buoy = 0
            if args.evl_chunk:
                evl_ch = int(args.evl_chunk)
            else:
                evl_ch = int(imgsw.size(-1))

            while (buoy + evl_ch) <= imgsw.size(-1):
                if args.thd:
                    pt = ptw[:, :, buoy:buoy + evl_ch]
                else:
                    pt = ptw

                imgs = imgsw[..., buoy:buoy + evl_ch]
                masks = masksw[..., buoy:buoy + evl_ch]
                buoy += evl_ch

                if args.thd:
                    pt = rearrange(pt, 'b n d -> (b d) n')
                    imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
                    masks = rearrange(masks, 'b c h w d -> (b d) c h w')
                    imgs = imgs.repeat(1, 3, 1, 1)
                    point_labels = torch.ones(imgs.size(0))

                    imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                    masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

                showp = pt

                mask_type = torch.float32
                ind += 1
                b_size, c, w, h = imgs.size()
                longsize = w if w >= h else h

                if point_labels.clone().flatten()[0] != -1:
                    # point_coords = samtrans.ResizeLongestSide(longsize).apply_coords(pt, (h, w))
                    point_coords = pt
                    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    if len(point_labels.shape) == 1:  # only one point prompt
                        coords_torch, labels_torch, showp = coords_torch[None, :, :], labels_torch[None, :], showp[None, :, :]
                    pt = (coords_torch, labels_torch)
                '''init'''
                if hard:
                    true_mask_ave = (true_mask_ave > 0.5).float()
                    # true_mask_ave = cons_tensor(true_mask_ave)
                imgs = imgs.to(dtype=mask_type, device=GPUdevice)
                '''test'''
                with torch.no_grad():
                    imge = net.image_encoder(imgs)
                    if args.net == 'sam' or args.net == 'mobile_sam':
                        se, de = net.prompt_encoder(
                            points=pt,
                            boxes=None,
                            masks=None,
                        )
                    elif args.net == "efficient_sam":
                        coords_torch, labels_torch = transform_prompt(coords_torch, labels_torch, h, w)
                        se = net.prompt_encoder(
                            coords=coords_torch,
                            labels=labels_torch,
                        )
                    if args.net == 'sam':
                        pred, _ = net.mask_decoder(
                            image_embeddings=imge,
                            image_pe=net.prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings=se,
                            dense_prompt_embeddings=de,
                            multimask_output=(args.multimask_output > 1),
                        )
                    elif args.net == 'mobile_sam':
                        pred, _ = net.mask_decoder(
                            image_embeddings=imge,
                            image_pe=net.prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings=se,
                            dense_prompt_embeddings=de,
                            multimask_output=False,
                        )
                    elif args.net == "efficient_sam":
                        se = se.view(
                            se.shape[0],
                            1,
                            se.shape[1],
                            se.shape[2],
                        )
                        pred, _ = net.mask_decoder(
                            image_embeddings=imge,
                            image_pe=net.prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings=se,
                            multimask_output=False,
                        )
                    # Resize to the desired output size
                    pred = F.interpolate(pred, size=(args.out_size, args.out_size))
                    tot += lossfunc(pred, masks)
                    '''vis images'''
                    if ind % args.vis == 0:
                        namecat = 'Test'
                        for na in name[:2]:
                            img_name = na.split('/')[-1].split('.')[0]
                            namecat = namecat + img_name + '+'
                        vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False, points=showp)

                    temp = eval_seg(pred, masks, threshold)
                    mix_res = tuple([sum(a) for a in zip(mix_res, temp)])

            pbar.update()
    if args.evl_chunk:
        n_val = n_val * (imgsw.size(-1) // evl_ch)

    return tot / n_val, tuple([a / n_val for a in mix_res])
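
# A matching sketch for validation, assuming the same driver script as in the
# train_sam example above:
#
#     val_loss, metrics = validation_sam(args, val_loader, epoch, net)
#     writer.add_scalar('Val/loss', val_loss, epoch)
#     # `metrics` holds the eval_seg scores averaged over validation batches.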


def transform_prompt(coord, label, h, w):
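    """Reshape and pad/truncate point prompts for the EfficientSAM decoder.

    Takes coordinates of shape (num_points, batch, 2) and labels of shape
    (num_points, batch), rescales coordinates with get_rescaled_pts, and pads
    (with -1) or truncates to decoder_max_num_input_points points.
    """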
    coord = coord.transpose(0, 1)
    label = label.transpose(0, 1)

    coord = coord.unsqueeze(1)
    label = label.unsqueeze(1)

    batch_size, max_num_queries, num_pts, _ = coord.shape
    rescaled_batched_points = get_rescaled_pts(coord, h, w)

    decoder_max_num_input_points = 6
    if num_pts > decoder_max_num_input_points:
        rescaled_batched_points = rescaled_batched_points[
            :, :, :decoder_max_num_input_points, :
        ]
        label = label[
            :, :, :decoder_max_num_input_points
        ]
    elif num_pts < decoder_max_num_input_points:
        rescaled_batched_points = F.pad(
            rescaled_batched_points,
            (0, 0, 0, decoder_max_num_input_points - num_pts),
            value=-1.0,
        )
        label = F.pad(
            label,
            (0, decoder_max_num_input_points - num_pts),
            value=-1.0,
        )

    rescaled_batched_points = rescaled_batched_points.reshape(
        batch_size * max_num_queries, decoder_max_num_input_points, 2
    )
    label = label.reshape(
        batch_size * max_num_queries, decoder_max_num_input_points
    )
    return rescaled_batched_points, label


def get_rescaled_pts(batched_points: torch.Tensor, input_h: int, input_w: int):
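    """Rescale point coordinates into the model's 1024x1024 input space,
    leaving padding points (negative coordinates) at -1.
    """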
    return torch.stack(
        [
            torch.where(
                batched_points[..., 0] >= 0,
                batched_points[..., 0] * 1024 / input_w,
                -1.0,
            ),
            torch.where(
                batched_points[..., 1] >= 0,
                batched_points[..., 1] * 1024 / input_h,
                -1.0,
            ),
        ],
        dim=-1,
    )
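

# A quick, illustrative shape trace for the two helpers above (values are
# arbitrary; shapes follow the (num_points, batch, 2) convention noted in
# transform_prompt):
#
#     coords = torch.tensor([[[100.0, 200.0]]])  # (1 point, batch 1, xy)
#     labels = torch.tensor([[1]])               # (1 point, batch 1)
#     pts, lbls = transform_prompt(coords, labels, h=512, w=512)
#     # pts:  (1, 6, 2) -- x,y rescaled to 200, 400; padded points are -1
#     # lbls: (1, 6)    -- padded with -1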