PULASki_ProbUNet2D_FID_MSSeg / ProbUNet_utils.py
soumickmj's picture
Upload ProbUNet
02a3ef3 verified
raw
history blame
8.41 kB
import argparse
import glob
import itertools
import os
import shutil

import numpy as np
import torch
# from trixi.util import Config, GridSearch
def check_attributes(object_, attributes):
    """Return True if ``object_`` has every attribute named in ``attributes``.

    Args:
        object_: Any object to inspect.
        attributes: Iterable of attribute names.

    Returns:
        bool: False as soon as at least one attribute is missing, else True
        (an empty ``attributes`` therefore yields True).
    """
    # Every name is probed (no short-circuit), then the result is derived
    # from whether anything was found missing.
    absent = [name for name in attributes if not hasattr(object_, name)]
    return not absent
def set_seeds(seed, cuda=True):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Either a single seed, used for all generators, or an indexable
            iterable whose items 0, 1 and 2 seed NumPy, PyTorch and the
            PyTorch CUDA generators respectively.
        cuda: If True, also seed all CUDA generators.
    """
    # A non-iterable seed is replicated so the indexing below is uniform.
    seeds = seed if hasattr(seed, "__iter__") else (seed,) * 3

    np.random.seed(seeds[0])
    torch.manual_seed(seeds[1])
    if cuda:
        torch.cuda.manual_seed_all(seeds[2])
def make_onehot(array, labels=None, axis=1, newaxis=False):
    """Convert a label array into a one-hot encoded array.

    Args:
        array: ``np.ndarray`` or ``torch.Tensor`` holding label values.
        labels: Values to encode, one output channel per value (in the given
            order). If None, the sorted unique values of ``array`` are used.
        axis: Axis along which the one-hot channels are placed. Negative
            values count from the end of the *output* array.
        newaxis: If True, a new axis of size ``len(labels)`` is inserted at
            ``axis``. If False, the existing ``axis`` is expanded: each of
            its entries becomes ``len(labels)`` consecutive channels.

    Returns:
        Array of the same kind (and, for tensors, device) and dtype as
        ``array``; entries are the result of ``array == label`` cast to that
        dtype.

    Raises:
        TypeError: If ``array`` is neither a NumPy array nor a torch tensor.
    """
    # Validate the input kind first so unsupported objects fail with the
    # intended TypeError instead of an AttributeError on ``.shape``.
    is_numpy = isinstance(array, np.ndarray)
    if not is_numpy and not torch.is_tensor(array):
        raise TypeError("Onehot conversion undefined for object of type {}".format(type(array)))

    # get labels if necessary
    if labels is None:
        if is_numpy:
            labels = np.unique(array).tolist()
        else:
            # torch.unique works on the tensor's own device, whereas
            # np.unique requires a conversion that fails for non-CPU tensors.
            labels = torch.unique(array).tolist()
    n_labels = len(labels)

    # Normalize a negative axis so that shape insertion and indexing below
    # refer to the same position in the output array.
    n_dims_in = len(array.shape)
    n_dims_out = n_dims_in + 1 if newaxis else n_dims_in
    if axis < 0:
        axis += n_dims_out

    # get target shape
    new_shape = list(array.shape)
    if newaxis:
        new_shape.insert(axis, n_labels)
    else:
        new_shape[axis] = new_shape[axis] * n_labels

    # make zero array
    if is_numpy:
        new_array = np.zeros(new_shape, dtype=array.dtype)
    else:
        new_array = torch.zeros(new_shape, dtype=array.dtype, device=array.device)

    # fill new array
    n_seg_channels = 1 if newaxis else array.shape[axis]
    for seg_channel in range(n_seg_channels):
        for label_index, label in enumerate(labels):
            new_slc = [slice(None)] * len(new_shape)
            slc = [slice(None)] * n_dims_in
            new_slc[axis] = seg_channel * n_labels + label_index
            if not newaxis:
                # Only compare against the source channel being expanded.
                slc[axis] = seg_channel
            new_array[tuple(new_slc)] = array[tuple(slc)] == label

    return new_array
def match_to(x, ref, keep_axes=(1,)):
    """Expand ``x`` to the shape of ``ref`` and move it to ``ref``'s device/dtype.

    The target shape is ``ref.shape`` except along ``keep_axes``, where the
    size of ``x`` is kept. A 1-D ``x`` is first given a leading axis of size
    one; a 2-D ``x`` then gets trailing singleton axes until it has as many
    dimensions as the target. The result is produced with ``Tensor.expand``,
    so ``x`` must be expandable to the target shape.

    Args:
        x: Tensor to expand.
        ref: Tensor providing the target shape, device and dtype.
        keep_axes: Axes along which the size of ``x`` (after the 1-D
            promotion described above) overrides the size of ``ref``.

    Returns:
        The expanded tensor with ``ref``'s device and dtype.
    """
    # Promote 1-D input before reading x.shape along keep_axes; otherwise the
    # default keep_axes=(1,) indexes past the end of a 1-D shape.
    if x.dim() == 1:
        x = x.unsqueeze(0)

    target_shape = list(ref.shape)
    for axis in keep_axes:
        target_shape[axis] = x.shape[axis]
    target_shape = tuple(target_shape)

    if x.dim() == 2:
        # Append singleton axes so expand() can broadcast over them.
        while x.dim() < len(target_shape):
            x = x.unsqueeze(-1)

    x = x.expand(*target_shape)
    x = x.to(device=ref.device, dtype=ref.dtype)

    return x
def make_slices(original_shape, patch_shape):
    """Tile the trailing dimensions of a shape into patch-sized slices.

    The last ``len(patch_shape)`` dimensions of ``original_shape`` are split
    into consecutive chunks of the corresponding patch size; if a dimension is
    not divisible by its patch size, a final smaller chunk covers the rest.

    Args:
        original_shape: Sequence of ints, the full array shape.
        patch_shape: Sequence of ints, patch size for the trailing dimensions.

    Returns:
        list: One entry per patch (Cartesian product over the per-dimension
        chunks). Each entry is a list of ``slice`` objects with
        ``slice(None)`` for the leading, unpatched dimensions.
    """
    n_patch_dims = len(patch_shape)
    # ``shape[-0:]`` would select the whole shape, so guard the empty case.
    working_shape = original_shape[-n_patch_dims:] if n_patch_dims else ()

    splits = []
    for size, patch in zip(working_shape, patch_shape):
        n_full = size // patch
        dim_slices = [slice(k * patch, (k + 1) * patch) for k in range(n_full)]
        rest = size % patch
        if rest > 0:
            # Start from the full-patch count rather than a leftover loop
            # variable, which is unbound (or stale from the previous
            # dimension) when the dimension is smaller than the patch.
            start = n_full * patch
            dim_slices.append(slice(start, start + rest))
        splits.append(dim_slices)

    # now we have all slices for the individual dimensions
    # we need their combinatorial combinations
    leading = [slice(None)] * (len(original_shape) - n_patch_dims)
    return [leading + list(combination) for combination in itertools.product(*splits)]
def coordinate_grid_samples(mean, std, factor_std=5, scale_std=1.):
    """Build a regular grid of sample points centred on ``mean``.

    Along every coordinate, ``2 * factor_std + 1`` equally spaced positions
    are taken between ``mean - scale_std * factor_std * std`` and
    ``mean + scale_std * factor_std * std``; the returned samples are all
    combinations of these positions (via ``np.meshgrid``).

    Args:
        mean: Per-coordinate centre values.
            # assumes a 1-D array-like, one entry per coordinate - TODO confirm
        std: Per-coordinate step scale, same layout as ``mean``.
        factor_std: Number of steps taken to each side of the mean.
        scale_std: Multiplier applied to the step size.

    Returns:
        list: One ``np.ndarray`` per grid point.
    """
    half_width = scale_std * factor_std
    offsets = np.linspace(-half_width, half_width, 2 * factor_std + 1)

    # Rows after the transpose are the candidate positions of one coordinate.
    per_coordinate = np.array([mean + offset * std for offset in offsets]).T

    grids = np.meshgrid(*per_coordinate)
    flattened = [list(grid.ravel()) for grid in grids]
    return [np.array(point) for point in zip(*flattened)]
def get_default_experiment_parser():
    """Create the command-line parser with the default experiment options.

    Returns:
        argparse.ArgumentParser: Parser with one positional argument
        (``base_dir``) and the optional flags listed below.
    """
    # (option strings, add_argument keyword arguments), in registration order.
    option_specs = (
        (("-c", "--config"),
         dict(type=str, default=None, help="Path to a config file.")),
        (("-v", "--visdomlogger"),
         dict(action="store_true", help="Use visdomlogger.")),
        (("-tx", "--tensorboardxlogger"),
         dict(type=str, default=None)),
        (("-tl", "--telegramlogger"),
         dict(action="store_true")),
        (("-dc", "--default_config"),
         dict(type=str, default="DEFAULTS", help="Select a default Config")),
        (("-ad", "--automatic_description"),
         dict(action="store_true")),
        (("-r", "--resume"),
         dict(type=str, default=None, help="Path to resume from")),
        (("-irc", "--ignore_resume_config"),
         dict(action="store_true", help="Ignore Config in experiment we resume from.")),
        (("-test", "--test"),
         dict(action="store_true", help="Run test instead of training")),
        (("-g", "--grid"),
         dict(type=str, help="Path to a config for grid search")),
        (("-s", "--skip_existing"),
         dict(action="store_true",
              help="Skip configs for which an experiment exists, only for grid search")),
        (("-m", "--mods"),
         dict(type=str, nargs="+", default=None,
              help="Mods are Config stubs to update only relevant parts for a certain setup.")),
        (("-ct", "--copy_test"),
         dict(action="store_true", help="Copy test files to original experiment.")),
    )

    parser = argparse.ArgumentParser()
    parser.add_argument("base_dir", type=str, help="Working directory for experiment.")
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser
def run_experiment(experiment, configs, args, mods=None, **kwargs):
    """Assemble the configuration from parsed CLI arguments and run ``experiment``.

    For every parameter combination of the (optional) grid search, the
    configuration is updated, an ``experiment`` instance is created and then
    trained and/or tested according to ``args``.

    NOTE(review): ``Config`` and ``GridSearch`` are not imported in the visible
    part of this file (the ``trixi.util`` import at the top is commented out),
    so they must be made available before this function can run.

    Args:
        experiment: Callable (typically an experiment class) invoked with
            ``config``, ``base_dir``, ``resume``, ``ignore_resume_config``,
            ``loggers`` and ``**kwargs``. The returned object must provide
            ``run()``, ``run_test(setup=...)`` and an ``elog`` attribute.
        configs: Mapping from config name to default config objects; entries
            ``args.default_config`` and (for automatic descriptions)
            ``"DEFAULTS"`` are read.
        args: Parsed arguments, e.g. from ``get_default_experiment_parser``.
        mods: Optional mapping from mod name to config updates, applied for
            every name in ``args.mods``.
        **kwargs: Forwarded to ``experiment``. Defaults are filled in for
            ``explogger_kwargs``, ``explogger_freq`` and ``resume_save_types``.
    """
    # set a few defaults
    kwargs.setdefault("explogger_kwargs", dict(folder_format="{experiment_name}_%Y%m%d-%H%M%S"))
    kwargs.setdefault("explogger_freq", 1)
    kwargs.setdefault("resume_save_types", ("model", "simple", "th_vars", "results"))

    config = Config(file_=args.config) if args.config is not None else Config()
    config.update_missing(configs[args.default_config].deepcopy())
    if args.mods is not None and mods is not None:
        for mod in args.mods:
            config.update(mods[mod])
    config = Config(config=config, update_from_argv=True)

    # GET EXISTING EXPERIMENTS TO BE ABLE TO SKIP CERTAIN CONFIGS
    if args.skip_existing:
        existing_configs = []
        for exp in os.listdir(args.base_dir):
            config_path = os.path.join(args.base_dir, exp, "config", "config.json")
            try:
                existing_configs.append(Config(file_=config_path))
            except Exception:
                # Deliberately best-effort: entries in base_dir without a
                # loadable config are simply not considered for skipping.
                pass

    if args.grid is not None:
        grid = GridSearch().read(args.grid)
    else:
        grid = [{}]

    for combi in grid:

        config.update(combi)

        if args.skip_existing:
            if any(existing_config.contains(config) for existing_config in existing_configs):
                continue

        if "backup_every" in config:
            kwargs["save_checkpoint_every_epoch"] = config["backup_every"]

        loggers = {}
        if args.visdomlogger:
            loggers["v"] = ("visdom", {}, 1)
        if args.tensorboardxlogger is not None:
            if args.tensorboardxlogger == "same":
                loggers["tx"] = ("tensorboard", {}, 1)
            else:
                loggers["tx"] = ("tensorboard", {"target_dir": args.tensorboardxlogger}, 1)

        if args.telegramlogger:
            kwargs["use_telegram"] = True

        if args.automatic_description:
            difference_to_default = Config.difference_config_static(
                config, configs["DEFAULTS"]
            ).flat(keep_lists=True, max_split_size=0, flatten_int=True)
            description_str = ""
            for key, val in difference_to_default.items():
                # Only the first element of each difference entry is reported.
                val = val[0]
                # New entries are prepended, so the last key ends up first.
                description_str = "{} = {}\n{}".format(key, val, description_str)
            config.description = description_str

        exp = experiment(config=config,
                         base_dir=args.base_dir,
                         resume=args.resume,
                         ignore_resume_config=args.ignore_resume_config,
                         loggers=loggers,
                         **kwargs)

        # Train unless we are resuming an experiment purely for testing.
        trained = False
        if args.resume is None or args.test is False:
            exp.run()
            trained = True
        if args.test:
            # Setup is only requested when run() was not executed before.
            exp.run_test(setup=not trained)

        # Copy test outputs back into the "save" folder of the resumed experiment.
        if isinstance(args.resume, str) and exp.elog is not None and args.copy_test:
            for f in glob.glob(os.path.join(exp.elog.save_dir, "test*")):
                if os.path.isdir(f):
                    shutil.copytree(f, os.path.join(args.resume, "save", os.path.basename(f)))
                else:
                    shutil.copy(f, os.path.join(args.resume, "save"))