3DFauna_demo / video3d /segmentation.py
kyleleey
first commit
98a77e0
raw
history blame
4.81 kB
import configargparse
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.utils as tvutils
import torchvision.transforms
from video3d.utils.segmentation_transforms import *
from video3d.utils.misc import setup_runtime
from video3d import networks
from video3d.trainer import Trainer
from video3d.dataloaders import SegmentationDataset
class Segmentation:
def __init__(self, cfgs, _):
self.cfgs = cfgs
self.device = cfgs.get('device', 'cpu')
self.total_loss = None
self.net = networks.EDDeconv(cin=3, cout=1, zdim=128, nf=64, activation=None)
self.optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, self.net.parameters()),
lr=cfgs.get('lr', 1e-4),
betas=(0.9, 0.999),
weight_decay=5e-4)
def load_model_state(self, cp):
self.net.load_state_dict(cp["net"])
def load_optimizer_state(self, cp):
self.net.load_state_dict(cp["optimizer"])
@staticmethod
def get_data_loaders(cfgs):
batch_size = cfgs.get('batch_size', 64)
num_workers = cfgs.get('num_workers', 4)
data_dir = cfgs.get('data_dir', './data')
img_size = cfgs.get('image_size', 64)
min_size = int(img_size * cfgs.get('aug_min_resize', 0.5))
max_size = int(img_size * cfgs.get('aug_max_resize', 2.0))
transform = Compose([RandomResize(min_size, max_size),
RandomHorizontalFlip(cfgs.get("aug_horizontal_flip", 0.4)),
RandomCrop(img_size),
ImageOnly(torchvision.transforms.ColorJitter(**cfgs.get("aug_color_jitter", {}))),
ImageOnly(torchvision.transforms.RandomGrayscale(cfgs.get("aug_grayscale", 0.2))),
ToTensor()])
train_loader = torch.utils.data.DataLoader(
SegmentationDataset(data_dir, is_validation=False, transform=transform, sequence_range=(0, 0.5)),
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True
)
transform = Compose([ToTensor()])
val_loader = torch.utils.data.DataLoader(
SegmentationDataset(data_dir, is_validation=True, transform=transform, sequence_range=(0.5, 1.0)),
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True
)
return train_loader, val_loader, None
def get_state_dict(self):
return {
"net": self.net.state_dict(),
"optimizer": self.optimizer.state_dict()
}
def to(self, device):
self.device = device
self.net.to(device)
def set_train(self):
self.net.train()
def set_eval(self):
self.net.eval()
def backward(self):
self.optimizer.zero_grad()
self.total_loss.backward()
self.optimizer.step()
def forward(self, batch, visualize=False):
image, target = batch
image = image.to(self.device)*2 - 1
target = target[:, 0, :, :].to(self.device).unsqueeze(1)
pred = self.net(image)
self.total_loss = nn.functional.binary_cross_entropy_with_logits(pred, target)
metrics = {'loss': self.total_loss}
visuals = {}
if visualize:
visuals['rgb'] = self.image_visual(image, normalize=True, range=(-1, 1))
visuals['target'] = self.image_visual(target, normalize=True, range=(0, 1))
visuals['pred'] = self.image_visual(nn.functional.sigmoid(pred), normalize=True, range=(0, 1))
return metrics, visuals
return metrics
def visualize(self, logger, total_iter, max_bs=25):
pass
def save_results(self, save_dir):
pass
def save_scores(self, path):
pass
@staticmethod
def image_visual(tensor, **kwargs):
if tensor.shape[1] == 1:
tensor = tensor.repeat(1, 3, 1, 1)
n = int(tensor.shape[0]**0.5 + 0.5)
tensor = tvutils.make_grid(tensor.detach(), nrow=n, **kwargs).permute(1, 2, 0)
return torch.clamp(tensor[:, :, :3] * 255, 0, 255).byte().cpu()
if __name__ == "__main__":
parser = configargparse.ArgumentParser(description='Training configurations.')
parser.add_argument('--config', default="config/train_segmentation.yml", type=str, is_config_file=True,
help='Specify a config file path')
parser.add_argument('--gpu', default=1, type=int, help='Specify a GPU device')
parser.add_argument('--seed', default=0, type=int, help='Specify a random seed')
args, _ = parser.parse_known_args()
cfgs = setup_runtime(args)
trainer = Trainer(cfgs, Segmentation)
trainer.train()