import argparse
import os
import shutil
import sys
import tempfile
import time
from collections import OrderedDict
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from einops import rearrange
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.transforms import AsDiscrete
from PIL import Image
from skimage import io
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score
from tensorboardX import SummaryWriter
#from dataset import *
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import cfg
import models.sam.utils.transforms as samtrans
import pytorch_ssim
#from models.discriminatorlayer import discriminator
from conf import settings
from utils import *
# from lucent.modelzoo.util import get_model_layers
# from lucent.optvis import render, param, transform, objectives
# from lucent.modelzoo import inceptionv1
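# Module-level training setup: parse CLI args, pin the target GPU, and build the
# default losses/metrics. BCEWithLogitsLoss (pos_weight=2) is used for binary masks,
# DiceCELoss for multi-class volumes, and the AsDiscrete/DiceMetric pair below
# assumes a 14-class one-hot label set.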
args = cfg.parse_args()
GPUdevice = torch.device('cuda', args.gpu_device)
pos_weight = torch.ones([1]).cuda(device=GPUdevice)*2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
seed = torch.randint(1,11,(args.b,7))
torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
def train_sam(args, net: nn.Module, optimizer, train_loader,
epoch, writer, schedulers=None, vis = 50):
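    """Run one training epoch of SAM-style prompted segmentation.

    Depending on args.mod, only the Adapter parameters, only the LoRA
    parameters, or the full image encoder is trained. Click prompts are taken
    from the batch (or created with generate_click_prompt), passed through the
    prompt encoder, and the decoder output is resized to args.out_size before
    the loss is computed. Returns the loss of the last batch.
    """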
    hard = 0
    epoch_loss = 0
    ind = 0
    # train mode
    net.train()
    optimizer.zero_grad()
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
if args.thd:
lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
else:
lossfunc = criterion_G
with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
for pack in train_loader:
# torch.cuda.empty_cache()
imgs = pack['image'].to(dtype = torch.float32, device = GPUdevice)
masks = pack['label'].to(dtype = torch.float32, device = GPUdevice)
# for k,v in pack['image_meta_dict'].items():
# print(k)
if 'pt' not in pack:
imgs, pt, masks = generate_click_prompt(imgs, masks)
else:
pt = pack['pt']
point_labels = pack['p_label']
name = pack['image_meta_dict']['filename_or_obj']
if args.thd:
imgs, pt, masks = generate_click_prompt(imgs, masks)
pt = rearrange(pt, 'b n d -> (b d) n')
imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
imgs = imgs.repeat(1,3,1,1)
point_labels = torch.ones(imgs.size(0))
imgs = torchvision.transforms.Resize((args.image_size,args.image_size))(imgs)
masks = torchvision.transforms.Resize((args.out_size,args.out_size))(masks)
showp = pt
mask_type = torch.float32
ind += 1
b_size,c,w,h = imgs.size()
longsize = w if w >=h else h
if point_labels.clone().flatten()[0] != -1:
# point_coords = samtrans.ResizeLongestSide(longsize).apply_coords(pt, (h, w))
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
if(len(point_labels.shape)==1): # only one point prompt
coords_torch, labels_torch, showp = coords_torch[None, :, :], labels_torch[None, :], showp[None, :, :]
pt = (coords_torch, labels_torch)
'''init'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
#true_mask_ave = cons_tensor(true_mask_ave)
# imgs = imgs.to(dtype = mask_type,device = GPUdevice)
'''Train'''
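            # Select which encoder parameters are trainable for this fine-tuning mode:
            # 'sam_adpt' trains only the Adapter layers, 'sam_lora'/'sam_adalora' train
            # only the LoRA parameters, and anything else fine-tunes the full encoder.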
if args.mod == 'sam_adpt':
for n, value in net.image_encoder.named_parameters():
if "Adapter" not in n:
value.requires_grad = False
else:
value.requires_grad = True
elif args.mod == 'sam_lora' or args.mod == 'sam_adalora':
from models.common import loralib as lora
lora.mark_only_lora_as_trainable(net.image_encoder)
if args.mod == 'sam_adalora':
# Initialize the RankAllocator
rankallocator = lora.RankAllocator(
net.image_encoder, lora_r=4, target_rank=8,
init_warmup=500, final_warmup=1500, mask_interval=10,
total_step=3000, beta1=0.85, beta2=0.85,
)
else:
for n, value in net.image_encoder.named_parameters():
value.requires_grad = True
imge= net.image_encoder(imgs)
with torch.no_grad():
if args.net == 'sam' or args.net == 'mobile_sam':
se, de = net.prompt_encoder(
points=pt,
boxes=None,
masks=None,
)
elif args.net == "efficient_sam":
coords_torch,labels_torch = transform_prompt(coords_torch,labels_torch,h,w)
se = net.prompt_encoder(
coords=coords_torch,
labels=labels_torch,
)
if args.net == 'sam':
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=(args.multimask_output > 1),
)
elif args.net == 'mobile_sam':
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False,
)
elif args.net == "efficient_sam":
se = se.view(
se.shape[0],
1,
se.shape[1],
se.shape[2],
)
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
multimask_output=False,
)
            # Resize to the desired output size
pred = F.interpolate(pred,size=(args.out_size,args.out_size))
loss = lossfunc(pred, masks)
pbar.set_postfix(**{'loss (batch)': loss.item()})
epoch_loss += loss.item()
# nn.utils.clip_grad_value_(net.parameters(), 0.1)
if args.mod == 'sam_adalora':
(loss+lora.compute_orth_regu(net, regu_weight=0.1)).backward()
optimizer.step()
rankallocator.update_and_mask(net, ind)
else:
loss.backward()
optimizer.step()
optimizer.zero_grad()
'''vis images'''
if vis:
if ind % vis == 0:
namecat = 'Train'
for na in name[:2]:
namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
vis_image(imgs,pred,masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
pbar.update()
return loss
def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
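    """Evaluate the network on the validation set.

    Volumes are processed in chunks of args.evl_chunk slices along the last
    dimension (or all slices at once when it is unset). Returns the average
    loss and a tuple of averaged segmentation metrics from eval_seg over the
    thresholds defined below.
    """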
# eval mode
net.eval()
mask_type = torch.float32
n_val = len(val_loader) # the number of batch
ave_res, mix_res = (0,0,0,0), (0,)*args.multimask_output*2
rater_res = [(0,0,0,0) for _ in range(6)]
tot = 0
hard = 0
threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
GPUdevice = torch.device('cuda:' + str(args.gpu_device))
device = GPUdevice
if args.thd:
lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
else:
lossfunc = criterion_G
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
for ind, pack in enumerate(val_loader):
imgsw = pack['image'].to(dtype = torch.float32, device = GPUdevice)
masksw = pack['label'].to(dtype = torch.float32, device = GPUdevice)
# for k,v in pack['image_meta_dict'].items():
# print(k)
if 'pt' not in pack or args.thd:
imgsw, ptw, masksw = generate_click_prompt(imgsw, masksw)
else:
ptw = pack['pt']
point_labels = pack['p_label']
name = pack['image_meta_dict']['filename_or_obj']
buoy = 0
if args.evl_chunk:
evl_ch = int(args.evl_chunk)
else:
evl_ch = int(imgsw.size(-1))
while (buoy + evl_ch) <= imgsw.size(-1):
if args.thd:
pt = ptw[:,:,buoy: buoy + evl_ch]
else:
pt = ptw
imgs = imgsw[...,buoy:buoy + evl_ch]
masks = masksw[...,buoy:buoy + evl_ch]
buoy += evl_ch
if args.thd:
pt = rearrange(pt, 'b n d -> (b d) n')
imgs = rearrange(imgs, 'b c h w d -> (b d) c h w ')
masks = rearrange(masks, 'b c h w d -> (b d) c h w ')
imgs = imgs.repeat(1,3,1,1)
point_labels = torch.ones(imgs.size(0))
imgs = torchvision.transforms.Resize((args.image_size,args.image_size))(imgs)
masks = torchvision.transforms.Resize((args.out_size,args.out_size))(masks)
showp = pt
mask_type = torch.float32
ind += 1
b_size,c,w,h = imgs.size()
longsize = w if w >=h else h
if point_labels.clone().flatten()[0] != -1:
# point_coords = samtrans.ResizeLongestSide(longsize).apply_coords(pt, (h, w))
point_coords = pt
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
if(len(point_labels.shape)==1): # only one point prompt
coords_torch, labels_torch, showp = coords_torch[None, :, :], labels_torch[None, :], showp[None, :, :]
pt = (coords_torch, labels_torch)
'''init'''
if hard:
true_mask_ave = (true_mask_ave > 0.5).float()
#true_mask_ave = cons_tensor(true_mask_ave)
imgs = imgs.to(dtype = mask_type,device = GPUdevice)
'''test'''
with torch.no_grad():
imge= net.image_encoder(imgs)
if args.net == 'sam' or args.net == 'mobile_sam':
se, de = net.prompt_encoder(
points=pt,
boxes=None,
masks=None,
)
elif args.net == "efficient_sam":
coords_torch,labels_torch = transform_prompt(coords_torch,labels_torch,h,w)
se = net.prompt_encoder(
coords=coords_torch,
labels=labels_torch,
)
if args.net == 'sam':
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=(args.multimask_output > 1),
)
elif args.net == 'mobile_sam':
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
dense_prompt_embeddings=de,
multimask_output=False,
)
elif args.net == "efficient_sam":
se = se.view(
se.shape[0],
1,
se.shape[1],
se.shape[2],
)
pred, _ = net.mask_decoder(
image_embeddings=imge,
image_pe=net.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=se,
multimask_output=False,
)
                    # Resize to the desired output size
pred = F.interpolate(pred,size=(args.out_size,args.out_size))
tot += lossfunc(pred, masks)
'''vis images'''
if ind % args.vis == 0:
namecat = 'Test'
                        for na in name[:2]:
                            img_name = na.split('/')[-1].split('.')[0]
                            namecat = namecat + img_name + '+'
vis_image(imgs,pred, masks, os.path.join(args.path_helper['sample_path'], namecat+'epoch+' +str(epoch) + '.jpg'), reverse=False, points=showp)
temp = eval_seg(pred, masks, threshold)
mix_res = tuple([sum(a) for a in zip(mix_res, temp)])
pbar.update()
if args.evl_chunk:
n_val = n_val * (imgsw.size(-1) // evl_ch)
return tot/ n_val , tuple([a/n_val for a in mix_res])
def transform_prompt(coord,label,h,w):
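    """Reshape and rescale click prompts for the EfficientSAM prompt encoder.

    Points are rescaled to the 1024x1024 model input space and padded or
    truncated to a fixed budget of 6 points per query; padded entries are
    marked with -1.
    """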
coord = coord.transpose(0,1)
label = label.transpose(0,1)
coord = coord.unsqueeze(1)
label = label.unsqueeze(1)
batch_size, max_num_queries, num_pts, _ = coord.shape
num_pts = coord.shape[2]
rescaled_batched_points = get_rescaled_pts(coord, h, w)
decoder_max_num_input_points = 6
if num_pts > decoder_max_num_input_points:
rescaled_batched_points = rescaled_batched_points[
:, :, : decoder_max_num_input_points, :
]
label = label[
:, :, : decoder_max_num_input_points
]
elif num_pts < decoder_max_num_input_points:
rescaled_batched_points = F.pad(
rescaled_batched_points,
(0, 0, 0, decoder_max_num_input_points - num_pts),
value=-1.0,
)
label = F.pad(
label,
(0, decoder_max_num_input_points - num_pts),
value=-1.0,
)
rescaled_batched_points = rescaled_batched_points.reshape(
batch_size * max_num_queries, decoder_max_num_input_points, 2
)
label = label.reshape(
batch_size * max_num_queries, decoder_max_num_input_points
)
return rescaled_batched_points,label
def get_rescaled_pts(batched_points: torch.Tensor, input_h: int, input_w: int):
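    """Map point coordinates from (input_w, input_h) to the 1024x1024 model
    space, keeping negative (padding/ignore) coordinates at -1."""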
return torch.stack(
[
torch.where(
batched_points[..., 0] >= 0,
batched_points[..., 0] * 1024 / input_w,
-1.0,
),
torch.where(
batched_points[..., 1] >= 0,
batched_points[..., 1] * 1024 / input_h,
-1.0,
),
],
dim=-1,
)