|
import numpy as np |
|
from tqdm import tqdm, trange |
|
import os |
|
import argparse |
|
from glob import glob |
|
import torch |
|
from torch import utils |
|
from torch.nn import functional as F |
|
from torchvision.transforms import functional as TF |
|
from torchvision.transforms import InterpolationMode |
|
|
|
from video_module.dataset import Video_DS |
|
from video_module.model import AFB_URR, FeatureBank |
|
from test_image_seg import test_waterseg |
|
import myutils |
|
|
|
# This script only runs inference, so autograd is switched off up front:
# no computation graph is recorded for tensor ops in this thread
# (PyTorch grad mode is thread-local).
torch.set_grad_enabled(mode=False)
|
|
|
|
|
def get_args(argv=None):
    """Parse command-line options for V-FloodNet water video segmentation.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) reads ``sys.argv``,
            which matches the previous zero-argument behaviour.

    Returns:
        argparse.Namespace with attributes ``gpu``, ``budget``, ``viz``,
        ``model_path``, ``update_rate``, ``merge_thres``, ``test_path`` and
        ``test_name``.
    """
    parser = argparse.ArgumentParser(description='V-FloodNet: Water Video Segmentation')
    parser.add_argument('--gpu', type=int, default=0, help='GPU card id.')
    parser.add_argument('--budget', type=int, default=250000,
                        help='Max number of features in the feature bank.')
    # Visualization is on by default. `--viz` is kept for backward
    # compatibility. On its own, store_true with default=True could never
    # turn visualization off, so `--no-viz` writes False into the same dest.
    parser.add_argument('--viz', dest='viz', action='store_true', default=True,
                        help='Visualize data.')
    parser.add_argument('--no-viz', dest='viz', action='store_false',
                        help='Disable writing overlay visualizations.')
    parser.add_argument('--model-path', type=str, required=True, help='Path to the checkpoint.')
    parser.add_argument('--update-rate', type=float, default=0.1,
                        help='Update Rate for merging new features.')
    parser.add_argument('--merge-thres', type=float, default=0.95,
                        help='Merging Rate threshold.')
    parser.add_argument('--test-path', type=str, required=True,
                        help='Path to the test video frames.')
    parser.add_argument('--test-name', type=str, required=True,
                        help='Name for the test video.')
    return parser.parse_args(argv)
|
|
|
|
|
def _load_checkpoint(model, model_path, device):
    """Load ``checkpoint['model']`` from ``model_path`` into ``model`` and log its metadata.

    Raises:
        FileNotFoundError: if ``model_path`` is not an existing file. This is a
            subclass of ``OSError``/``IOError``, so handlers written for the
            previous bare ``IOError`` still catch it.
    """
    if not os.path.isfile(model_path):
        print(myutils.gct(), f'No checkpoint found at {model_path}')
        raise FileNotFoundError(f'No checkpoint found at {model_path}')

    # map_location puts tensors on the device we run on, rather than on the
    # device they were saved from (e.g. a different GPU index).
    checkpoint = torch.load(model_path, map_location=device)
    # strict=False: mismatched keys are tolerated, as in the original script.
    model.load_state_dict(checkpoint['model'], strict=False)

    # These entries are only logged, so a missing one should not abort the run.
    end_epoch = checkpoint.get('epoch')
    train_loss = checkpoint.get('loss')
    seed = checkpoint.get('seed')
    print(myutils.gct(),
          f'Loaded checkpoint {model_path}. (end_epoch: {end_epoch}, train_loss: {train_loss}, seed: {seed})')


def _list_frames(test_path):
    """Return the sorted ``.jpg``/``.png`` frame paths in ``test_path``.

    Raises:
        FileNotFoundError: if the folder contains no matching frames.
    """
    frames = sorted(glob(os.path.join(test_path, '*.jpg')) + glob(os.path.join(test_path, '*.png')))
    if not frames:
        raise FileNotFoundError(f'No .jpg or .png frames found in {test_path}')
    return frames


def main(args, device):
    """Segment water in a folder of video frames by propagating the first frame's mask.

    The first-frame mask is read from
    ``./output/segs/<test_name>/mask/<first_frame_stem>.png``. If that file does
    not exist it is first produced by the single-image model via
    ``test_waterseg``. The first frame and mask are memorized into a
    ``FeatureBank``. Each frame yielded by ``Video_DS`` is then segmented
    against the bank, and the bank is updated with the prediction. Whether the
    loader also yields the first frame depends on ``Video_DS``, which is not
    visible here.

    Masks are written to ``.../<test_name>/mask``. When ``args.viz`` is true,
    overlays are written to ``.../<test_name>/overlay``.

    Args:
        args: Namespace from ``get_args()``. Uses ``model_path``,
            ``test_path``, ``test_name``, ``budget``, ``update_rate``,
            ``merge_thres`` and ``viz``.
        device: torch device that the model and all tensors are placed on.

    Raises:
        FileNotFoundError: if the checkpoint is missing, the folder has no
            frames, or the first-frame mask was not produced.
    """
    model = AFB_URR(device, update_bank=True, load_imagenet_params=False)
    model = model.to(device)
    model.eval()

    # Frames are resized to this size before entering the model. Predictions
    # are resized back to each frame's original size before saving.
    downsample_size = 480

    _load_checkpoint(model, args.model_path, device)

    img_list = _list_frames(args.test_path)
    first_frame_pil = myutils.load_image_in_PIL(img_list[0])
    first_name = os.path.splitext(os.path.basename(img_list[0]))[0]

    out_dir = './output/segs'
    # The first-frame mask input and the per-frame outputs share this folder.
    seg_dir = os.path.join(out_dir, args.test_name, 'mask')
    mask_path = os.path.join(seg_dir, first_name + '.png')
    if not os.path.exists(mask_path):
        image_model_path = './records/link_efficientb4_model.pth'
        # Expected to write the first-frame mask to mask_path. Verify against
        # test_waterseg's output layout.
        test_waterseg(image_model_path, img_list[0], args.test_name, out_dir, device)
        if not os.path.exists(mask_path):
            raise FileNotFoundError(f'First-frame mask was not produced at {mask_path}')

    first_mask_pil = myutils.load_image_in_PIL(mask_path, 'P')
    seq_dataset = Video_DS(img_list, first_frame_pil, first_mask_pil)
    seq_loader = utils.data.DataLoader(seq_dataset, batch_size=1, shuffle=False, num_workers=1)

    os.makedirs(seg_dir, exist_ok=True)
    overlay_dir = None
    if args.viz:
        overlay_dir = os.path.join(out_dir, args.test_name, 'overlay')
        os.makedirs(overlay_dir, exist_ok=True)

    obj_n = seq_dataset.obj_n
    fb = FeatureBank(obj_n, args.budget, device, update_rate=args.update_rate, thres_close=args.merge_thres)

    ori_first_frame = seq_dataset.first_frame.unsqueeze(0).to(device)
    ori_first_mask = seq_dataset.first_mask.unsqueeze(0).to(device)
    first_frame = TF.resize(ori_first_frame, downsample_size, InterpolationMode.BICUBIC)
    # NEAREST keeps the mask values discrete (no blending between labels).
    first_mask = TF.resize(ori_first_mask, downsample_size, InterpolationMode.NEAREST)

    # Assumes first_mask is (obj_n, H, W) with one channel per object, so the
    # argmax over dim 0 gives a label map. TODO confirm against Video_DS.
    pred = torch.argmax(ori_first_mask[0], dim=0).cpu().numpy().astype(np.uint8)
    myutils.save_seg_mask(pred, os.path.join(seg_dir, f'{first_name}.png'), myutils.color_palette)
    if args.viz:
        myutils.save_overlay(ori_first_frame[0], pred, os.path.join(overlay_dir, f'{first_name}.png'),
                             myutils.color_palette)

    with torch.no_grad():
        k4_list, v4_list = model.memorize(first_frame, first_mask)
        fb.init_bank(k4_list, v4_list)

        for idx, (frame, frame_name) in enumerate(tqdm(seq_loader)):
            ori_frame = frame.to(device)
            ori_size = ori_frame.shape[-2:]
            frame = TF.resize(ori_frame, downsample_size, InterpolationMode.BICUBIC)

            score, _ = model.segment(frame, fb)
            pred_mask = F.softmax(score, dim=1)

            # Feed the soft prediction back into the feature bank. The index
            # passed is idx + 1, presumably because the first frame is index 0
            # (verify against FeatureBank.update).
            k4_list, v4_list = model.memorize(frame, pred_mask)
            fb.update(k4_list, v4_list, idx + 1)

            # Upsample the class probabilities to the original size, then take
            # the per-pixel argmax to get the label map.
            pred = TF.resize(pred_mask, ori_size, InterpolationMode.BICUBIC)
            pred = torch.argmax(pred[0], dim=0).cpu().numpy().astype(np.uint8)
            pred = myutils.postprocessing_pred(pred)

            # batch_size=1, so frame_name is a length-1 batch of names.
            name = frame_name[0]
            myutils.save_seg_mask(pred, os.path.join(seg_dir, f'{name}.png'), myutils.color_palette)
            if args.viz:
                myutils.save_overlay(ori_frame[0], pred, os.path.join(overlay_dir, f'{name}.png'),
                                     myutils.color_palette)

    fb.print_peak_mem()
|
|
|
|
|
if __name__ == '__main__':
    args = get_args()
    print(myutils.gct(), 'Args =', args)

    # Only CUDA inference is supported. Report which requirement failed
    # instead of one combined, possibly misleading, message.
    if args.gpu < 0:
        raise ValueError('CUDA is required. --gpu must be >= 0.')
    if not torch.cuda.is_available():
        raise ValueError('CUDA is required, but torch.cuda.is_available() returned False.')
    if args.gpu >= torch.cuda.device_count():
        raise ValueError(f'--gpu {args.gpu} is out of range: '
                         f'{torch.cuda.device_count()} CUDA device(s) visible.')
    device = torch.device('cuda', args.gpu)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.isdir(args.test_path):
        raise NotADirectoryError(f'--test-path is not a directory: {args.test_path}')

    main(args, device)

    print(myutils.gct(), 'Test video segmentation done.')
|
|