rlawjdghek's picture
detectron2
a9a0ec2
raw
history blame
3.48 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import os
import unittest
import tempfile
from itertools import count
from detectron2.config import LazyConfig, LazyCall as L
from omegaconf import DictConfig
class TestLazyPythonConfig(unittest.TestCase):
    """Tests for ``LazyConfig`` loading, saving, overriding and ``to_py`` export.

    The fixture configs (``root_cfg.py`` and the files under ``dir1/``) are
    looked up in the directory that contains this test module.
    """

    def setUp(self):
        # Fixture configs are resolved relative to this test module.
        self.curr_dir = os.path.dirname(__file__)
        self.root_filename = os.path.join(self.curr_dir, "root_cfg.py")

    def _dir1_file(self, name):
        """Return the path of fixture file ``name`` inside the ``dir1`` directory."""
        return os.path.join(self.curr_dir, "dir1", name)

    def test_load(self):
        """Loading the root config yields the expected values, and re-loading
        returns a fresh config unaffected by mutations made to an earlier one."""
        cfg = LazyConfig.load(self.root_filename)

        self.assertEqual(cfg.dir1a_dict.a, "modified")
        self.assertEqual(cfg.dir1b_dict.a, 1)
        self.assertEqual(cfg.lazyobj.x, "base_a_1")

        cfg.lazyobj.x = "new_x"
        # reload: the mutation above must not leak into a newly loaded config
        cfg = LazyConfig.load(self.root_filename)
        self.assertEqual(cfg.lazyobj.x, "base_a_1")

    def test_save_load(self):
        """A config survives a YAML save/load round trip.

        The only expected difference is ``_target_``. In memory it is the
        callable itself (``itertools.count``). After the round trip it is the
        dotted import path string.
        """
        cfg = LazyConfig.load(self.root_filename)
        with tempfile.TemporaryDirectory(prefix="detectron2") as d:
            fname = os.path.join(d, "test_config.yaml")
            LazyConfig.save(cfg, fname)
            cfg2 = LazyConfig.load(fname)

        self.assertEqual(cfg2.lazyobj._target_, "itertools.count")
        self.assertEqual(cfg.lazyobj._target_, count)
        cfg2.lazyobj.pop("_target_")
        cfg.lazyobj.pop("_target_")
        # the rest are equal
        self.assertEqual(cfg, cfg2)

    def test_failed_save(self):
        """Saving a config that holds a lambda still writes both ``fname`` and
        ``fname + ".pkl"``."""
        # The allow_objects flag lets the DictConfig store a non-primitive
        # value (the lambda). NOTE(review): presumably the lambda cannot be
        # represented in YAML, hence the extra ".pkl" file — confirm against
        # LazyConfig.save.
        cfg = DictConfig({"x": lambda: 3}, flags={"allow_objects": True})
        with tempfile.TemporaryDirectory(prefix="detectron2") as d:
            fname = os.path.join(d, "test_config.yaml")
            LazyConfig.save(cfg, fname)
            # Assert while still inside the `with`: the temporary directory
            # and everything in it are removed when the context exits.
            self.assertTrue(os.path.exists(fname))
            self.assertTrue(os.path.exists(fname + ".pkl"))

    def test_overrides(self):
        """Dotted-key overrides are applied and their values parsed.

        A bare ``123`` becomes an int. A quoted ``"123"`` and a bare ``abc``
        both stay strings.
        """
        cfg = LazyConfig.load(self.root_filename)
        LazyConfig.apply_overrides(cfg, ["lazyobj.x=123", 'dir1b_dict.a="123"'])
        self.assertEqual(cfg.dir1b_dict.a, "123")
        self.assertEqual(cfg.lazyobj.x, 123)

        LazyConfig.apply_overrides(cfg, ["dir1b_dict.a=abc"])
        self.assertEqual(cfg.dir1b_dict.a, "abc")

    def test_invalid_overrides(self):
        """Overriding a key nested under a non-dict value raises ``KeyError``.

        ``lazyobj.x`` is the string ``"base_a_1"`` in the root config.
        """
        cfg = LazyConfig.load(self.root_filename)
        with self.assertRaises(KeyError):
            LazyConfig.apply_overrides(cfg, ["lazyobj.x.xxx=123"])

    def test_to_py(self):
        """``to_py`` renders the config as Python assignment statements.

        Lazy calls, including nested ones, are emitted as calls to their
        dotted target.
        """
        cfg = LazyConfig.load(self.root_filename)
        cfg.lazyobj.x = {"a": 1, "b": 2, "c": L(count)(x={"r": "a", "s": 2.4, "t": [1, 2, 3, "z"]})}
        cfg.list = ["a", 1, "b", 3.2]
        py_str = LazyConfig.to_py(cfg)
        # The expected text is black-style formatted (one argument per line,
        # trailing commas, 4-space indents). Whitespace inside this literal is
        # significant and must match the generated output exactly.
        expected = """cfg.dir1a_dict.a = "modified"
cfg.dir1a_dict.b = 2
cfg.dir1b_dict.a = 1
cfg.dir1b_dict.b = 2
cfg.lazyobj = itertools.count(
    x={
        "a": 1,
        "b": 2,
        "c": itertools.count(x={"r": "a", "s": 2.4, "t": [1, 2, 3, "z"]}),
    },
    y="base_a_1_from_b",
)
cfg.list = ["a", 1, "b", 3.2]
"""
        self.assertEqual(py_str, expected)

    def test_bad_import(self):
        """A config with a relative import raises ``ImportError`` mentioning it."""
        path = self._dir1_file("bad_import.py")
        with self.assertRaisesRegex(ImportError, "relative import"):
            LazyConfig.load(path)

    def test_bad_import2(self):
        """A config that imports a missing file raises ``ImportError`` ("not exist")."""
        path = self._dir1_file("bad_import2.py")
        with self.assertRaisesRegex(ImportError, "not exist"):
            LazyConfig.load(path)

    def test_load_rel(self):
        """``load_rel.py`` loads successfully and defines key ``x``."""
        path = self._dir1_file("load_rel.py")
        cfg = LazyConfig.load(path)
        self.assertIn("x", cfg)