File size: 926 Bytes
62409ee e88e8f7 62409ee b60b62e c47be26 621caea 62409ee 19e28e3 62409ee 19e28e3 62409ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import sys
from pathlib import Path
import wandb
# Absolute path of this script.
FILE = Path(__file__).resolve()
# Put the directory three levels above this file on the import path. The
# imports that follow (`train`, `utils.general`, ...) are expected to resolve
# from that directory, i.e. it must contain train.py and the utils package.
sys.path.append(FILE.parents[3].as_posix())  # make `train` and `utils` importable
from train import train, parse_opt
from utils.general import increment_path
from utils.torch_utils import select_device
from utils.callbacks import Callbacks
def sweep():
    """Run one training job for a Weights & Biases sweep agent.

    Starts a W&B run, reads the hyperparameters the run's config holds,
    overrides the matching parsed options (``batch_size``, ``epochs`` and
    ``data``) and calls ``train`` with them. A key missing from the config
    sets the corresponding option to ``None``, because the values are read
    with ``dict.get``.
    """
    wandb.init()

    # Get hyp dict from sweep agent. `_items` is read from the config object's
    # own attributes, so it is the object's live storage rather than a
    # snapshot. Pass train() a copy so that any in-place change it makes to the
    # hyperparameters cannot write back into the run config.
    # NOTE(review): relies on the private `_items` attribute of wandb.config.
    hyp_dict = vars(wandb.config).get("_items").copy()

    # Workaround: get necessary opt args
    opt = parse_opt(known=True)
    opt.batch_size = hyp_dict.get("batch_size")
    opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve))
    opt.epochs = hyp_dict.get("epochs")
    opt.nosave = True
    opt.data = hyp_dict.get("data")
    device = select_device(opt.device, batch_size=opt.batch_size)

    # train
    train(hyp_dict, opt, device, callbacks=Callbacks())
# Script entry point: run a single sweep trial when this file is executed
# directly (not when it is imported).
if __name__ == "__main__":
    sweep()
|