import inspect
import pickle
import platform
from types import GeneratorType
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import catalogue
import pytest
try:
from pydantic.v1 import BaseModel, PositiveInt, StrictFloat, constr
from pydantic.v1.types import StrictBool
except ImportError:
from pydantic import BaseModel, StrictFloat, PositiveInt, constr # type: ignore
from pydantic.types import StrictBool # type: ignore
from confection import Config, ConfigValidationError
from confection.tests.util import Cat, make_tempdir, my_registry
from confection.util import Generator, partial
# Sample config text used by the parsing tests: an optimizer block with a
# nested learning-rate schedule, and a pipeline block whose embedding width
# is a variable reference to another value in the same config.
EXAMPLE_CONFIG = """
[optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
use_averages = true
[optimizer.learn_rate]
@schedules = "warmup_linear.v1"
initial_rate = 0.1
warmup_steps = 10000
total_steps = 100000
[pipeline]
[pipeline.classifier]
name = "classifier"
factory = "classifier"
[pipeline.classifier.model]
@layers = "ClassifierModel.v1"
hidden_depth = 1
hidden_width = 64
token_vector_width = 128
[pipeline.classifier.model.embedding]
@layers = "Embedding.v1"
width = ${pipeline.classifier.model:token_vector_width}
"""
# Optimizer-only config text; the tests below resolve it and round-trip it
# through str, bytes and disk, comparing the stripped output to this literal.
OPTIMIZER_CFG = """
[optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
use_averages = true
[optimizer.learn_rate]
@schedules = "warmup_linear.v1"
initial_rate = 0.1
warmup_steps = 10000
total_steps = 100000
"""
class HelloIntsSchema(BaseModel):
    """Schema with two required integer fields, ``hello`` and ``world``."""

    hello: int
    world: int

    class Config:
        # Keys that are not declared on the schema are rejected.
        extra = "forbid"
class DefaultsSchema(BaseModel):
    """Schema with one required integer and one string field that has a default."""

    required: int
    optional: str = "default value"

    class Config:
        # Keys that are not declared on the schema are rejected.
        extra = "forbid"
class ComplexSchema(BaseModel):
    """Two-level schema: top-level fields plus nested sub-schemas.

    ``level2_req`` must be supplied as a ``HelloIntsSchema``; ``level2_opt``
    falls back to a ``DefaultsSchema`` built with ``required=1``.
    """

    outer_req: int
    outer_opt: str = "default value"
    level2_req: HelloIntsSchema
    level2_opt: DefaultsSchema = DefaultsSchema(required=1)
# Argument dicts for the "catsie.v1" entry of the "cats" registry, covering
# every combination of the two boolean flags "evil" and "cute".
good_catsie = {
    "@cats": "catsie.v1",
    "evil": False,
    "cute": True,
}
ok_catsie = {
    "@cats": "catsie.v1",
    "evil": False,
    "cute": False,
}
bad_catsie = {
    "@cats": "catsie.v1",
    "evil": True,
    "cute": True,
}
worst_catsie = {
    "@cats": "catsie.v1",
    "evil": True,
    "cute": False,
}
def test_validate_simple_config():
    """A config that matches the schema comes back unchanged from ``_fill``."""
    payload = {"hello": 1, "world": 2}
    filled, _, validated = my_registry._fill(payload, HelloIntsSchema)
    assert filled == payload
    assert validated == payload
def test_invalidate_simple_config():
    """A string in an integer field produces exactly one integer type error."""
    bad_payload = {"hello": 1, "world": "hi!"}
    with pytest.raises(ConfigValidationError) as raised:
        my_registry._fill(bad_payload, HelloIntsSchema)
    failure = raised.value
    assert len(failure.errors) == 1
    assert "type_error.integer" in failure.error_types
def test_invalidate_extra_args():
    """A key the schema does not declare makes ``_fill`` raise."""
    payload_with_extra = {"hello": 1, "world": 2, "extra": 3}
    with pytest.raises(ConfigValidationError):
        my_registry._fill(payload_with_extra, HelloIntsSchema)
def test_fill_defaults_simple_config():
    """Schema defaults are filled in; a missing required field is an error."""
    only_required = {"required": 1}
    filled, _, _ = my_registry._fill(only_required, DefaultsSchema)
    assert filled["required"] == 1
    assert filled["optional"] == "default value"

    # Leaving out the required field must be rejected.
    only_optional = {"optional": "some value"}
    with pytest.raises(ConfigValidationError):
        my_registry._fill(only_optional, DefaultsSchema)
def test_fill_recursive_config():
    """Defaults are filled at the top level and inside nested sub-schemas."""
    nested_cfg = {"outer_req": 1, "level2_req": {"hello": 4, "world": 7}}
    filled, _, _ = my_registry._fill(nested_cfg, ComplexSchema)
    # Top level: supplied value plus the default
    assert filled["outer_req"] == 1
    assert filled["outer_opt"] == "default value"
    # Supplied nested section is kept
    supplied = filled["level2_req"]
    assert supplied["hello"] == 4
    assert supplied["world"] == 7
    # Omitted nested section is created from the schema default
    defaulted = filled["level2_opt"]
    assert defaulted["required"] == 1
    assert defaulted["optional"] == "default value"
def test_is_promise():
    """``is_promise`` is true for dicts with an ``@`` key and false otherwise."""
    assert my_registry.is_promise(good_catsie)
    for not_a_promise in ({"hello": "world"}, 1):
        assert not my_registry.is_promise(not_a_promise)
    # A dict naming two registries still counts as a promise here.
    two_constructors = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"}
    assert my_registry.is_promise(two_constructors)
def test_get_constructor():
    """``get_constructor`` returns the (registry name, entry name) pair."""
    expected = ("cats", "catsie.v1")
    assert my_registry.get_constructor(good_catsie) == expected
def test_parse_args():
    """``parse_args`` splits a promise into positional and keyword arguments."""
    positional, keyword = my_registry.parse_args(bad_catsie)
    assert positional == []
    assert keyword == {"evil": True, "cute": True}
def test_make_promise_schema():
    """The schema generated for a promise has a field per function argument."""
    promise_schema = my_registry.make_promise_schema(good_catsie)
    for field_name in ("evil", "cute"):
        assert field_name in promise_schema.__fields__
def test_validate_promise():
    """A promise stays as-is in the filled config and is called for validation."""
    cfg_with_promise = {"required": 1, "optional": good_catsie}
    filled, _, validated = my_registry._fill(cfg_with_promise, DefaultsSchema)
    assert filled == cfg_with_promise
    assert validated == {"required": 1, "optional": "meow"}
def test_fill_validate_promise():
    """Arguments missing from a promise are filled in ("cute" becomes True)."""
    partial_promise = {"@cats": "catsie.v1", "evil": False}
    cfg = {"required": 1, "optional": partial_promise}
    filled, _, _ = my_registry._fill(cfg, DefaultsSchema)
    assert filled["optional"]["cute"] is True
def test_fill_invalidate_promise():
    """Filling fails against the wrong schema and for an unknown promise argument."""
    cfg = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}}
    # The config does not fit HelloIntsSchema at all
    with pytest.raises(ConfigValidationError):
        my_registry._fill(cfg, HelloIntsSchema)
    # Add an argument to the promise, then fill against the matching schema
    cfg["optional"]["whiskers"] = True
    with pytest.raises(ConfigValidationError):
        my_registry._fill(cfg, DefaultsSchema)
def test_create_registry():
    """A new "dogs" registry starts empty and holds one entry after registering."""
    dogs = catalogue.create(my_registry.namespace, "dogs", entry_points=False)
    my_registry.dogs = dogs
    assert hasattr(my_registry, "dogs")
    assert len(my_registry.dogs.get_all()) == 0
    my_registry.dogs.register("good_boy.v1", func=lambda x: x)
    assert len(my_registry.dogs.get_all()) == 1
def test_registry_methods():
    """``get`` raises ValueError for an unknown registry and for a None entry."""
    with pytest.raises(ValueError):
        my_registry.get("dfkoofkds", "catsie.v1")
    # Register None under a fresh name, then look it up
    register_under_name = my_registry.cats.register("catsie.v123")
    register_under_name(None)
    with pytest.raises(ValueError):
        my_registry.get("cats", "catsie.v123")
def test_resolve_no_schema():
    """Promises nested in plain dicts are resolved without an explicit schema."""
    nested = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}}
    resolved = my_registry.resolve({"cfg": nested})["cfg"]
    assert resolved["one"] == 1
    assert resolved["two"] == {"three": "scratch!"}

    # "evil" is given as the string "true" here, which must be rejected
    wrong_type = {"two": {"three": {"@cats": "catsie.v1", "evil": "true"}}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(wrong_type)
def test_resolve_schema():
    """Resolving against a nested schema accepts valid input and rejects bad input."""

    class TestBaseSubSchema(BaseModel):
        three: str

    class TestBaseSchema(BaseModel):
        one: PositiveInt
        two: TestBaseSubSchema

        class Config:
            extra = "forbid"

    class TestSchema(BaseModel):
        cfg: TestBaseSchema

    # Valid: positive "one", and "two.three" is a promise
    valid = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}}
    my_registry.resolve({"cfg": valid}, schema=TestSchema)

    # "one" is not a positive int
    negative_one = {"one": -1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"cfg": negative_one}, schema=TestSchema)

    # "three" is required in subschema
    missing_three = {"one": 1, "two": {"four": {"@cats": "catsie.v1", "evil": True}}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"cfg": missing_three}, schema=TestSchema)
def test_resolve_schema_coerced():
    """Schema coercion shows up in the resolved config but not in the filled one."""

    class TestBaseSchema(BaseModel):
        test1: str
        test2: bool
        test3: float

    class TestSchema(BaseModel):
        cfg: TestBaseSchema

    raw = {"test1": 123, "test2": 1, "test3": 5}
    filled = my_registry.fill({"cfg": raw}, schema=TestSchema)
    resolved = my_registry.resolve({"cfg": raw}, schema=TestSchema)
    coerced = {"test1": "123", "test2": True, "test3": 5.0}
    assert resolved["cfg"] == coerced
    # This only affects the resolved config, not the filled config
    assert filled["cfg"] == raw
def test_read_config():
    """Parse EXAMPLE_CONFIG from bytes and spot-check values, including the variable."""
    cfg = Config().from_bytes(EXAMPLE_CONFIG.encode("utf8"))
    optimizer = cfg["optimizer"]
    assert optimizer["beta1"] == 0.9
    assert optimizer["learn_rate"]["initial_rate"] == 0.1
    classifier = cfg["pipeline"]["classifier"]
    assert classifier["factory"] == "classifier"
    # The width is a reference to token_vector_width (128)
    assert classifier["model"]["embedding"]["width"] == 128
def test_optimizer_config():
    """Resolving OPTIMIZER_CFG gives an optimizer object with the configured beta1."""
    parsed = Config().from_str(OPTIMIZER_CFG)
    resolved = my_registry.resolve(parsed, validate=True)
    assert resolved["optimizer"].beta1 == 0.9
def test_config_to_str():
    """``to_str`` reproduces the parsed text, even when the Config had prior content."""
    expected = OPTIMIZER_CFG.strip()
    # Start from an empty Config
    from_empty = Config().from_str(OPTIMIZER_CFG)
    assert from_empty.to_str().strip() == expected
    # Start from a Config that already holds an unrelated value
    from_prefilled = Config({"optimizer": {"foo": "bar"}}).from_str(OPTIMIZER_CFG)
    assert from_prefilled.to_str().strip() == expected
def test_config_to_str_creates_intermediate_blocks():
    """A section that only contains a subsection is still written as its own block."""
    nested = Config({"optimizer": {"foo": {"bar": 1}}})
    rendered = nested.to_str().strip()
    assert (
        rendered
        == """
[optimizer]
[optimizer.foo]
bar = 1
""".strip()
    )
def test_config_to_str_escapes():
    """Check that ``$`` is written as ``$$`` in config text and survives a round trip.

    Parsing the escaped text and constructing from the plain dict must give
    the same config, and serializing that config and parsing it again must
    give back the original values.
    """
    section_str = """
[section]
node1 = "^a$$"
node2 = "$$b$$c"
"""
    section_dict = {"section": {"node1": "^a$", "node2": "$b$c"}}
    # parse from escaped string
    cfg = Config().from_str(section_str)
    assert cfg == section_dict
    # parse from non-escaped dict
    cfg = Config(section_dict)
    assert cfg == section_dict
    # roundtrip through str
    cfg_str = cfg.to_str()
    assert "^a$$" in cfg_str
    new_cfg = Config().from_str(cfg_str)
    # Check the re-parsed config, not the one it was serialized from;
    # otherwise the round trip itself is never verified.
    assert new_cfg == section_dict
def test_config_roundtrip_bytes():
    """A config serialized to bytes and loaded again renders the same text."""
    serialized = Config().from_str(OPTIMIZER_CFG).to_bytes()
    reloaded = Config().from_bytes(serialized)
    assert reloaded.to_str().strip() == OPTIMIZER_CFG.strip()
def test_config_roundtrip_disk():
    """A config written to a temp directory and read back renders the same text."""
    original = Config().from_str(OPTIMIZER_CFG)
    with make_tempdir() as tmp_dir:
        target = tmp_dir / "config.cfg"
        original.to_disk(target)
        # Read back while the temporary directory is still open
        reloaded = Config().from_disk(target)
    assert reloaded.to_str().strip() == OPTIMIZER_CFG.strip()
def test_config_roundtrip_disk_respects_path_subclasses(pathy_fixture):
    """Disk round trip using a path built from the ``pathy_fixture`` fixture."""
    original = Config().from_str(OPTIMIZER_CFG)
    target = pathy_fixture / "config.cfg"
    original.to_disk(target)
    reloaded = Config().from_disk(target)
    assert reloaded.to_str().strip() == OPTIMIZER_CFG.strip()
def test_config_to_str_invalid_defaults():
    """Top-level values outside any section must raise, on writing and on reading.

    Such keys would otherwise be interpreted as [DEFAULT], which causes the
    values to be included in *all* other sections.
    """
    # Writing: "one" sits at the top level, outside any section
    top_level_value = {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}
    with pytest.raises(ConfigValidationError):
        Config(top_level_value).to_str()
    # Reading: an explicit [DEFAULT] section
    default_section = "[DEFAULT]\none = 1"
    with pytest.raises(ConfigValidationError):
        Config().from_str(default_section)
def test_validation_custom_types():
    """Pydantic constrained types in function annotations are enforced on resolve."""

    def complex_args(
        rate: StrictFloat,
        steps: PositiveInt = 10,  # type: ignore
        log_level: constr(regex="(DEBUG|INFO|WARNING|ERROR)") = "ERROR",  # noqa: F821
    ):
        return None

    # Create the "complex" registry and register the function above in it
    my_registry.complex = catalogue.create(
        my_registry.namespace, "complex", entry_points=False
    )
    my_registry.complex("complex.v1")(complex_args)

    # All arguments valid
    valid = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"}
    my_registry.resolve({"config": valid})

    # steps is not a positive int
    bad_steps = {"@complex": "complex.v1", "rate": 1.0, "steps": -1, "log_level": "INFO"}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"config": bad_steps})

    # log_level is not a string matching the regex
    bad_level = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "none"}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"config": bad_level})

    # top-level object is promise: rejected by both resolve and fill
    top_level = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(top_level)
    with pytest.raises(ConfigValidationError):
        my_registry.fill(top_level)

    # two constructors
    two_registries = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"config": two_registries})
def test_validation_no_validate():
    """With ``validate=False`` a wrongly typed argument is passed through as given."""
    inner = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": "false"}}}
    resolved = my_registry.resolve({"cfg": inner}, validate=False)
    filled = my_registry.fill({"cfg": inner}, validate=False)
    # Resolved: the promise was called
    assert resolved["cfg"]["one"] == 1
    assert resolved["cfg"]["two"] == {"three": "scratch!"}
    # Filled: the string is kept as-is and the default for "cute" is added
    filled_promise = filled["cfg"]["two"]["three"]
    assert filled_promise["evil"] == "false"
    assert filled_promise["cute"] is True
def test_validation_fill_defaults():
    """Defaults are filled with or without validation; bad values fail validation."""
    unchecked = {"cfg": {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}}
    filled = my_registry.fill(unchecked, validate=False)
    assert len(filled["cfg"]["two"]) == 3
    # With validation on, the same config ("evil" is the string "hello")
    # is rejected
    with pytest.raises(ConfigValidationError):
        my_registry.fill(unchecked)

    # Fill in with new defaults
    v2_cfg = {"cfg": {"one": 1, "two": {"@cats": "catsie.v2", "evil": False}}}
    filled = my_registry.fill(v2_cfg)
    filled_promise = filled["cfg"]["two"]
    assert len(filled_promise) == 4
    assert filled_promise["evil"] is False
    assert filled_promise["cute"] is True
    assert filled_promise["cute_level"] == 1
def test_make_config_positional_args():
    """Values under the "*" key are passed to the function as positional args."""

    @my_registry.cats("catsie.v567")
    def catsie_567(*args: Optional[str], foo: str = "bar"):
        assert args[0] == "^_^"
        assert args[1] == "^(*.*)^"
        assert foo == "baz"
        return args[0]

    positional = ["^_^", "^(*.*)^"]
    promise = {"@cats": "catsie.v567", "foo": "baz", "*": positional}
    resolved = my_registry.resolve({"config": promise})
    assert resolved["config"] == "^_^"
def test_fill_config_positional_args_w_promise():
    """Filling keeps a promise under "*" as written and adds keyword defaults."""

    @my_registry.cats("catsie.v568")
    def catsie_568(*args: str, foo: str = "bar"):
        assert args[0] == "^(*.*)^"
        assert foo == "baz"
        return args[0]

    @my_registry.cats("cat_promise.v568")
    def cat_promise() -> str:
        return "^(*.*)^"

    outer = {
        "@cats": "catsie.v568",
        "*": {"promise": {"@cats": "cat_promise.v568"}},
    }
    filled = my_registry.fill({"config": outer}, validate=True)
    filled_outer = filled["config"]
    assert filled_outer["foo"] == "bar"
    assert filled_outer["*"] == {"promise": {"@cats": "cat_promise.v568"}}
def test_make_config_positional_args_complex():
    """Positional args are validated against the annotation of ``*args``."""

    @my_registry.cats("catsie.v890")
    def catsie_890(*args: Optional[Union[StrictBool, PositiveInt]]):
        assert args[0] == 123
        return args[0]

    accepted = {"config": {"@cats": "catsie.v890", "*": [123, True, 1, False]}}
    assert my_registry.resolve(accepted)["config"] == 123

    # "True" is not a valid boolean or positive int
    rejected = {"config": {"@cats": "catsie.v890", "*": [123, "True"]}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(rejected)
def test_positional_args_to_from_string():
    """The "*" key round-trips through text and is filled and resolved correctly."""
    # Plain round trip, "*" as a list value
    text = """[a]\nb = 1\n* = ["foo","bar"]"""
    assert Config().from_str(text).to_str() == text
    # Plain round trip, "*" as subsections
    text = """[a]\nb = 1\n\n[a.*.bar]\ntest = 2\n\n[a.*.foo]\ntest = 1"""
    assert Config().from_str(text).to_str() == text

    @my_registry.cats("catsie.v666")
    def catsie_666(*args, meow=False):
        return args

    # "*" as a list on a promise
    text = """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]"""
    filled_text = my_registry.fill(Config().from_str(text)).to_str()
    assert filled_text == """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]\nmeow = false"""
    resolved = my_registry.resolve(Config().from_str(text))
    assert resolved == {"a": ("foo", "bar")}

    # "*" as a subsection on a promise
    text = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\nx = 1"""
    filled_text = my_registry.fill(Config().from_str(text)).to_str()
    assert filled_text == """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\nx = 1"""
    resolved = my_registry.resolve(Config().from_str(text))
    assert resolved == {"a": ({"x": 1},)}

    @my_registry.cats("catsie.v777")
    def catsie_777(y: int = 1):
        return "meow" * y

    # "*" subsection that is itself a promise: its default is filled in
    text = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777\""""
    filled_text = my_registry.fill(Config().from_str(text)).to_str()
    expected_text = """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 1"""
    assert filled_text == expected_text

    # ...and it is resolved before being passed on positionally
    text = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 3"""
    resolved = my_registry.resolve(Config().from_str(text))
    assert resolved == {"a": ("meowmeowmeow",)}
def test_validation_generators_iterable():
    """Register an optimizer and a generator schedule, then resolve the optimizer."""

    @my_registry.optimizers("test_optimizer.v1")
    def test_optimizer_v1(rate: float) -> None:
        return None

    @my_registry.schedules("test_schedule.v1")
    def test_schedule_v1(some_value: float = 1.0) -> Iterable[float]:
        while True:
            yield some_value

    optimizer_cfg = {"@optimizers": "test_optimizer.v1", "rate": 0.1}
    my_registry.resolve({"optimizer": optimizer_cfg})
def test_validation_unset_type_hints():
    """Test that unset type hints are handled correctly (and treated as Any)."""

    @my_registry.optimizers("test_optimizer.v2")
    def test_optimizer_v2(rate, steps: int = 10) -> None:
        return None

    promise = {"@optimizers": "test_optimizer.v2", "rate": 0.1, "steps": 20}
    my_registry.resolve({"test": promise})
def test_validation_bad_function():
    """Errors raised inside a function propagate; bad call arguments are config errors."""

    @my_registry.optimizers("bad.v1")
    def bad() -> None:
        raise ValueError("This is an error in the function")

    @my_registry.optimizers("good.v1")
    def good() -> None:
        return None

    # Bad function: its own ValueError comes through unchanged
    raising_cfg = {"test": {"@optimizers": "bad.v1"}}
    with pytest.raises(ValueError):
        my_registry.resolve(raising_cfg)

    # Bad function call: the function takes no "invalid_arg"
    bad_call_cfg = {"test": {"@optimizers": "good.v1", "invalid_arg": 1}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(bad_call_cfg)
def test_objects_from_config():
    """A nested schedule promise is resolved and handed to the optimizer."""
    schedule = {
        "@schedules": "my_cool_repetitive_schedule.v1",
        "base_rate": 0.001,
        "repeat": 4,
    }
    optimizer_cfg = {
        "@optimizers": "my_cool_optimizer.v1",
        "beta1": 0.2,
        "learn_rate": schedule,
    }
    optimizer = my_registry.resolve({"optimizer": optimizer_cfg})["optimizer"]
    assert optimizer.beta1 == 0.2
    assert optimizer.learn_rate == [0.001] * 4
def test_partials_from_config():
    """Test that functions registered with partial applications are handled
    correctly (e.g. initializers)."""
    numpy = pytest.importorskip("numpy")

    def uniform_init(
        shape: Tuple[int, ...], *, lo: float = -0.1, hi: float = 0.1
    ) -> List[float]:
        return numpy.random.uniform(lo, hi, shape).tolist()

    @my_registry.initializers("uniform_init.v1")
    def configure_uniform_init(
        *, lo: float = -0.1, hi: float = 0.1
    ) -> Callable[[List[float]], List[float]]:
        return partial(uniform_init, lo=lo, hi=hi)

    entry = "uniform_init.v1"
    resolved = my_registry.resolve({"test": {"@initializers": entry, "lo": -0.2}})
    init_fn = resolved["test"]
    assert hasattr(init_fn, "__call__")
    params = inspect.signature(init_fn).parameters
    # The partial will still have lo as an arg, just with default
    assert len(params) == 3
    # Make sure returned partial function has correct value set
    assert inspect.signature(init_fn).parameters["lo"].default == -0.2
    # Actually call the function and verify
    assert numpy.asarray(init_fn((2, 3))).shape == (2, 3)

    # Make sure validation still works: wrong type for "lo"
    wrong_type = {"test": {"@initializers": entry, "lo": [0.5]}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(wrong_type)
    # ...and an argument the function does not accept
    unknown_arg = {"test": {"@initializers": entry, "lo": -0.2, "other": 10}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(unknown_arg)
def test_partials_from_config_nested():
    """Test that partial functions are passed correctly to other registered
    functions that consume them (e.g. initializers -> layers)."""

    def test_initializer(a: int, b: int = 1) -> int:
        return a * b

    @my_registry.initializers("test_initializer.v1")
    def configure_test_initializer(b: int = 1) -> Callable[[int], int]:
        return partial(test_initializer, b=b)

    @my_registry.layers("test_layer.v1")
    def test_layer(init: Callable[[int], int], c: int = 1) -> Callable[[int], int]:
        return lambda x: x + init(c)

    layer_cfg = {
        "@layers": "test_layer.v1",
        "c": 5,
        "init": {"@initializers": "test_initializer.v1", "b": 10},
    }
    layer = my_registry.resolve({"test": layer_cfg})["test"]
    # init(c) is 5 * 10 = 50, added to the input
    assert layer(1) == 51
    assert layer(100) == 150
def test_validate_generator():
    """Test that generator replacement for validation in config doesn't
    actually replace the returned value."""

    @my_registry.schedules("test_schedule.v2")
    def test_schedule():
        while True:
            yield 10

    schedule_promise = {"@schedules": "test_schedule.v2"}
    # Generator resolved directly
    resolved = my_registry.resolve({"test": schedule_promise})["test"]
    assert isinstance(resolved, GeneratorType)

    @my_registry.optimizers("test_optimizer.v2")
    def test_optimizer2(rate: Generator) -> Generator:
        return rate

    # Generator passed as a single argument
    as_argument = {
        "@optimizers": "test_optimizer.v2",
        "rate": {"@schedules": "test_schedule.v2"},
    }
    resolved = my_registry.resolve({"test": as_argument})["test"]
    assert isinstance(resolved, GeneratorType)

    @my_registry.optimizers("test_optimizer.v3")
    def test_optimizer3(schedules: Dict[str, Generator]) -> Generator:
        return schedules["rate"]

    # Generator passed inside a dict argument
    inside_dict = {
        "@optimizers": "test_optimizer.v3",
        "schedules": {"rate": {"@schedules": "test_schedule.v2"}},
    }
    resolved = my_registry.resolve({"test": inside_dict})["test"]
    assert isinstance(resolved, GeneratorType)

    # Registered only; nothing below resolves this entry
    @my_registry.optimizers("test_optimizer.v4")
    def test_optimizer4(*schedules: Generator) -> Generator:
        return schedules[0]
def test_handle_generic_type():
    """Test that validation can handle checks against arbitrary generic
    types in function argument annotations."""
    inner_cat = {"@cats": "int_cat.v1", "value_in": 3}
    outer_cat = {"@cats": "generic_cat.v1", "cat": inner_cat}
    resolved_cat = my_registry.resolve({"test": outer_cat})["test"]
    assert isinstance(resolved_cat, Cat)
    assert resolved_cat.value_in == 3
    assert resolved_cat.value_out is None
    assert resolved_cat.name == "generic_cat"
@pytest.mark.parametrize(
    "cfg",
    [
        "[a]\nb = 1\nc = 2\n\n[a.c]\nd = 3",
        "[a]\nb = 1\n\n[a.c]\nd = 2\n\n[a.c.d]\ne = 3",
    ],
)
def test_handle_error_duplicate_keys(cfg):
    """A name used both as a value and as a section must raise a config error.

    This would otherwise cause a very cryptic error when interpreting the
    config (TypeError: 'X' object does not support item assignment).
    """
    parser = Config()
    with pytest.raises(ConfigValidationError):
        parser.from_str(cfg)
@pytest.mark.parametrize(
    "cfg,is_valid",
    [("[a]\nb = 1\n\n[a.c]\nd = 3", True), ("[a]\nb = 1\n\n[A.c]\nd = 2", False)],
)
def test_cant_expand_undefined_block(cfg, is_valid):
    """Test that you can't expand a block that hasn't been created yet. This
    comes up when you typo a name, and if we allow expansion of undefined blocks,
    it's very hard to create good errors for those typos.
    """
    if not is_valid:
        with pytest.raises(ConfigValidationError):
            Config().from_str(cfg)
        return
    # Valid input only has to parse without raising
    Config().from_str(cfg)
def test_fill_config_overrides():
    """``fill`` applies overrides given as dotted keys or as whole sections."""
    base = {
        "cfg": {
            "one": 1,
            "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}},
        }
    }
    # A single argument inside a promise
    single_value = {"cfg.two.three.evil": False}
    filled = my_registry.fill(base, overrides=single_value, validate=True)
    assert filled["cfg"]["two"]["three"]["evil"] is False

    # Test that promises can be overwritten as well
    whole_promise = {"cfg.two.three": 3}
    filled = my_registry.fill(base, overrides=whole_promise, validate=True)
    assert filled["cfg"]["two"]["three"] == 3

    # Test that value can be overwritten with promises and that the result is
    # interpreted and filled correctly
    with_promise = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}}
    filled = my_registry.fill(base, overrides=with_promise)
    assert filled["cfg"]["two"] is None
    filled_one = filled["cfg"]["one"]
    assert filled_one["@cats"] == "catsie.v1"
    assert filled_one["evil"] is False
    assert filled_one["cute"] is True

    # Overwriting with wrong types should cause validation error
    wrong_type = {"cfg.two.three.evil": 20}
    with pytest.raises(ConfigValidationError):
        my_registry.fill(base, overrides=wrong_type, validate=True)

    # Overwriting with incomplete promises should cause validation error
    incomplete = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}}
    with pytest.raises(ConfigValidationError):
        my_registry.fill(base, overrides=incomplete)

    # Overrides that don't match config should raise error
    unknown_section = {"cfg.two.three.evil": False, "two.four": True}
    with pytest.raises(ConfigValidationError):
        my_registry.fill(base, overrides=unknown_section, validate=True)
    unknown_key = {"cfg.five": False}
    with pytest.raises(ConfigValidationError):
        my_registry.fill(base, overrides=unknown_key, validate=True)
def test_resolve_overrides():
    """``resolve`` applies overrides before calling the registered functions."""
    base = {
        "cfg": {
            "one": 1,
            "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}},
        }
    }
    # A single argument inside a promise changes what the promise returns
    single_value = {"cfg.two.three.evil": False}
    resolved = my_registry.resolve(base, overrides=single_value, validate=True)
    assert resolved["cfg"]["two"]["three"] == "meow"

    # Test that promises can be overwritten as well
    whole_promise = {"cfg.two.three": 3}
    resolved = my_registry.resolve(base, overrides=whole_promise, validate=True)
    assert resolved["cfg"]["two"]["three"] == 3

    # Test that value can be overwritten with promises
    with_promise = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}}
    resolved = my_registry.resolve(base, overrides=with_promise)
    assert resolved["cfg"]["one"] == "meow"
    assert resolved["cfg"]["two"] is None

    # Overwriting with wrong types should cause validation error
    wrong_type = {"cfg.two.three.evil": 20}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(base, overrides=wrong_type, validate=True)

    # Overwriting with incomplete promises should cause validation error
    incomplete = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(base, overrides=incomplete)

    # Overrides that don't match config should raise error
    unknown_section = {"cfg.two.three.evil": False, "cfg.two.four": True}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(base, overrides=unknown_section, validate=True)
    unknown_key = {"cfg.five": False}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(base, overrides=unknown_key, validate=True)
@pytest.mark.parametrize(
    "prop,expected",
    [("a.b.c", True), ("a.b", True), ("a", True), ("a.e", True), ("a.b.c.d", False)],
)
def test_is_in_config(prop, expected):
    """``_is_in_config`` reports whether a dotted path exists in a nested dict."""
    nested = {"a": {"b": {"c": 5, "d": 6}, "e": [1, 2]}}
    found = my_registry._is_in_config(prop, nested)
    assert found is expected
def test_resolve_prefilled_values():
    """An object instance placed directly in the config is passed through."""

    class Language(object):
        def __init__(self):
            ...

    @my_registry.optimizers("prefilled.v1")
    def prefilled(nlp: Language, value: int = 10):
        return (nlp, value)

    # Passing an instance of Language here via the config is bad, since it
    # won't serialize to a string, but we still test for it
    promise = {"@optimizers": "prefilled.v1", "nlp": Language(), "value": 50}
    resolved = my_registry.resolve({"test": promise}, validate=True)
    nlp, value = resolved["test"][0], resolved["test"][1]
    assert isinstance(nlp, Language)
    assert value == 50
def test_fill_config_dict_return_type():
    """Test that a registered function returning a dict is handled correctly."""

    @my_registry.cats.register("catsie_with_dict.v1")
    def catsie_with_dict(evil: StrictBool) -> Dict[str, bool]:
        return {"not_evil": not evil}

    inner = {"test": {"@cats": "catsie_with_dict.v1", "evil": False}, "foo": 10}
    # Filled config keeps the arguments, not the function's return value
    filled = my_registry.fill({"cfg": inner}, validate=True)
    filled_promise = filled["cfg"]["test"]
    assert filled_promise["evil"] is False
    assert "not_evil" not in filled_promise
    # Resolved config holds the returned dict
    resolved = my_registry.resolve({"cfg": inner}, validate=True)
    assert resolved["cfg"]["test"]["not_evil"] is True
def test_deepcopy_config():
    """``Config.copy`` returns an equal config that is a different object."""
    original = Config({"a": 1, "b": {"c": 2, "d": 3}})
    duplicate = original.copy()
    # Same values but not same object
    assert original == duplicate
    assert original is not duplicate
@pytest.mark.skipif(
    platform.python_implementation() == "PyPy", reason="copy does not fail for pypy"
)
def test_deepcopy_config_pickle():
    """Copying a config that holds the numpy module itself raises ValueError."""
    numpy = pytest.importorskip("numpy")
    # Check for error if value can't be pickled/deepcopied
    uncopyable = Config({"a": 1, "b": numpy})
    with pytest.raises(ValueError):
        uncopyable.copy()
def test_config_to_str_simple_promises():
    """Test that references to function registries without arguments are
    serialized inline as dict."""
    text = """[section]\nsubsection = {"@registry":"value"}"""
    parsed = Config().from_str(text)
    assert parsed["section"]["subsection"]["@registry"] == "value"
    # Written back out, the inline dict form is preserved
    assert parsed.to_str() == text
def test_config_from_str_invalid_section():
    """A section nested under a key that already holds a value must raise."""
    # "a.b" is null, then used as a section
    direct_child = """[a]\nb = null\n\n[a.b]\nc = 1"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(direct_child)
    # "a.b" is null, then used as the parent of a deeper section
    deeper_child = """[a]\nb = null\n\n[a.b.c]\nd = 1"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(deeper_child)
def test_config_to_str_order():
    """Test that Config.to_str orders the sections."""
    nested = {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": {"g": {"h": {"i": 4, "j": 5}}}}
    rendered = Config(nested).to_str()
    expected = (
        "[a]\ne = 3\n\n[a.b]\nc = 1\nd = 2\n\n[f]\n\n[f.g]\n\n[f.g.h]\ni = 4\nj = 5"
    )
    assert rendered == expected
@pytest.mark.parametrize("d", [".", ":"])
def test_config_interpolation(d):
    """Test that config values are interpolated correctly. The parametrized
    value is the final divider (${a.b} vs. ${a:b}). Both should now work and be
    valid. The double {{ }} in the config strings are required to prevent the
    references from being interpreted as an actual f-string variable.
    """
    # A reference without a section is invalid
    no_section = """[a]\nfoo = "hello"\n\n[b]\nbar = ${foo}"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(no_section)

    # Bare reference to a string value
    text = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}"""
    assert Config().from_str(text)["b"]["bar"] == "hello"
    # Reference followed by more text, unquoted
    text = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}!"""
    assert Config().from_str(text)["b"]["bar"] == "hello!"
    # Reference followed by more text, quoted
    text = f"""[a]\nfoo = "hello"\n\n[b]\nbar = "${{a{d}foo}}!\""""
    assert Config().from_str(text)["b"]["bar"] == "hello!"
    # Integer value inside a longer string
    text = f"""[a]\nfoo = 15\n\n[b]\nbar = ${{a{d}foo}}!"""
    assert Config().from_str(text)["b"]["bar"] == "15!"
    # Bare reference to a list value
    text = f"""[a]\nfoo = ["x", "y"]\n\n[b]\nbar = ${{a{d}foo}}"""
    assert Config().from_str(text)["b"]["bar"] == ["x", "y"]

    # Interpolation within the same section
    text = f"""[a]\nfoo = "x"\nbar = ${{a{d}foo}}\nbaz = "${{a{d}foo}}y\""""
    assert Config().from_str(text)["a"]["bar"] == "x"
    assert Config().from_str(text)["a"]["baz"] == "xy"
def test_config_interpolation_lists():
    """Check variable interpolation inside list values and nested list/dict values."""
    # Test that lists are preserved correctly
    c_str = """[a]\nb = 1\n\n[c]\nd = ["hello ${a.b}", "world"]"""
    config = Config().from_str(c_str, interpolate=False)
    assert config["c"]["d"] == ["hello ${a.b}", "world"]
    config = config.interpolate()
    assert config["c"]["d"] == ["hello 1", "world"]
    c_str = """[a]\nb = 1\n\n[c]\nd = [${a.b}, "hello ${a.b}", "world"]"""
    config = Config().from_str(c_str)
    assert config["c"]["d"] == [1, "hello 1", "world"]
    # The result of this parse is not asserted on (the checks below are
    # commented out); the call only has to complete without raising.
    config = Config().from_str(c_str, interpolate=False)
    # NOTE: This currently doesn't work, because we can't know how to JSON-load
    # the uninterpolated list [${a.b}].
    # assert config["c"]["d"] == ["${a.b}", "hello ${a.b}", "world"]
    # config = config.interpolate()
    # assert config["c"]["d"] == [1, "hello 1", "world"]
    c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", ${a}]"""
    config = Config().from_str(c_str)
    assert config["c"]["d"] == ["hello", {"b": 1}]
    # A section reference inside a string is expected to raise
    c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", "hello ${a}"]"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(c_str)
    config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": ["hello ${a.b}"], "y": 2}]"""
    config = Config().from_str(config_str)
    assert config["c"]["d"] == ["hello", {"x": ["hello 1"], "y": 2}]
    # NOTE(review): the string assigned to config_str here is never parsed;
    # the call below uses c_str, which still holds the "hello ${a}" text from
    # above, so this repeats the earlier check. Confirm whether config_str
    # was intended and whether that input is actually expected to raise.
    config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": [${a.b}], "y": 2}]"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(c_str)
@pytest.mark.parametrize("d", [".", ":"])
def test_config_interpolation_sections(d):
    """Test that config sections are interpolated correctly. The parametrized
    value is the final divider (${a.b} vs. ${a:b}). Both should now work and be
    valid. The double {{ }} in the config strings are required to prevent the
    references from being interpreted as an actual f-string variable.
    """
    # Simple block references
    text = """[a]\nfoo = "hello"\nbar = "world"\n\n[b]\nc = ${a}"""
    parsed = Config().from_str(text)
    assert parsed["b"]["c"] == parsed["a"]

    # References with non-string values
    text = f"""[a]\nfoo = "hello"\n\n[a.x]\ny = ${{a{d}b}}\n\n[a.b]\nc = 1\nd = [10]"""
    parsed = Config().from_str(text)
    assert parsed["a"]["x"]["y"] == parsed["a"]["b"]

    # Multiple references in the same string
    text = f"""[a]\nx = "string"\ny = 10\n\n[b]\nz = "${{a{d}x}}/${{a{d}y}}\""""
    parsed = Config().from_str(text)
    assert parsed["b"]["z"] == "string/10"

    # Non-string references in string (converted to string)
    text = f"""[a]\nx = ["hello", "world"]\n\n[b]\ny = "result: ${{a{d}x}}\""""
    parsed = Config().from_str(text)
    assert parsed["b"]["y"] == 'result: ["hello", "world"]'

    # References to sections referencing sections
    text = """[a]\nfoo = "x"\n\n[b]\nbar = ${a}\n\n[c]\nbaz = ${b}"""
    parsed = Config().from_str(text)
    assert parsed["b"]["bar"] == parsed["a"]
    assert parsed["c"]["baz"] == parsed["b"]

    # References to section values referencing other sections
    text = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b{d}bar}}"""
    parsed = Config().from_str(text)
    assert parsed["c"]["baz"] == parsed["b"]["bar"]

    # References to sections with subsections
    text = """[a]\nfoo = "x"\n\n[a.b]\nbar = 100\n\n[c]\nbaz = ${a}"""
    parsed = Config().from_str(text)
    assert parsed["c"]["baz"] == parsed["a"]

    # Infinite recursion
    text = """[a]\nfoo ="x"\n\n[a.b]\nbar = ${a}"""
    parsed = Config().from_str(text)
    assert parsed["a"]["b"]["bar"] == parsed["a"]

    # We can't reference not-yet interpolated subsections
    not_yet_interpolated = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b.bar{d}foo}}"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(not_yet_interpolated)

    # Generally invalid references
    missing_target = f"""[a]\nfoo = ${{b{d}bar}}"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(missing_target)

    # We can't reference sections or promises within strings
    section_in_string = """[a]\n\n[a.b]\nfoo = "x: ${c}"\n\n[c]\nbar = 1\nbaz = 2"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(section_in_string)
def test_config_from_str_overrides():
    """Check how dotted-key overrides given to ``Config.from_str`` are applied."""
    base_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = 3\n\n[f]\ng = {"x": "y"}"""

    # Basic value substitution: overridden keys change, siblings stay put
    replaced = Config().from_str(base_str, overrides={"a.b": 10, "a.c.d": 20})
    assert replaced["a"]["b"] == 10
    assert replaced["a"]["c"]["d"] == 20
    assert replaced["a"]["c"]["e"] == 3

    # Valid values that previously weren't in config
    extended = Config().from_str(base_str, overrides={"a.c.f": 100})
    assert extended["a"]["c"]["d"] == 2
    assert extended["a"]["c"]["e"] == 3
    assert extended["a"]["c"]["f"] == 100

    # Invalid keys and sections. The second override currently isn't expected
    # to work, because the dict in f.g is not interpreted as a section while
    # the config is still just the configparser
    for rejected in ({"f": 10}, {"f.g.x": "z"}):
        with pytest.raises(ConfigValidationError):
            Config().from_str(base_str, overrides=rejected)

    # With variables (values): the reference picks up the overridden value
    value_ref_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = ${a:b}"""
    with_value_ref = Config().from_str(value_ref_str, overrides={"a.b": 10})
    assert with_value_ref["a"]["b"] == 10
    assert with_value_ref["a"]["c"]["e"] == 10

    # With variables (sections): the referenced section carries the override
    section_ref_str = """[a]\nb = 1\n\n[a.c]\nd = 2\n[e]\nf = ${a.c}"""
    with_section_ref = Config().from_str(section_ref_str, overrides={"a.c.d": 20})
    assert with_section_ref["a"]["c"]["d"] == 20
    assert with_section_ref["e"]["f"] == {"d": 20}
def test_config_reserved_aliases():
    """Check that a registered function may take an argument named after a
    reserved attribute ("validate") and still be resolved, filled and
    validated through the auto-generated pydantic schema."""

    @my_registry.cats("catsie.with_alias")
    def catsie_with_alias(validate: StrictBool = False):
        return validate

    # A proper boolean is passed through to the function unchanged
    valid_block = {"@cats": "catsie.with_alias", "validate": True}
    resolved = my_registry.resolve({"test": valid_block})
    filled = my_registry.fill({"test": valid_block})
    assert resolved["test"] is True
    assert filled["test"] == valid_block

    # A non-boolean value must be rejected by the StrictBool annotation
    invalid_block = {"@cats": "catsie.with_alias", "validate": 20}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"test": invalid_block})
@pytest.mark.parametrize("d", [".", ":"])
def test_config_no_interpolation(d):
    """Check that variables survive loading with ``interpolate=False`` and can
    be interpolated afterwards.

    ``d`` is the final divider of a reference (${a.b} vs. ${a:b}); both forms
    are exercised. The doubled {{ }} in the config strings keep the references
    from being read as f-string replacement fields.
    """
    numpy = pytest.importorskip("numpy")
    raw_str = f"""[a]\nb = 1\n\n[c]\nd = ${{a{d}b}}\ne = \"hello${{a{d}b}}"\nf = ${{a}}"""

    # Loaded without interpolation, the references are kept verbatim
    raw = Config().from_str(raw_str, interpolate=False)
    assert not raw.is_interpolated
    assert raw["c"]["d"] == f"${{a{d}b}}"
    assert raw["c"]["e"] == f'"hello${{a{d}b}}"'
    assert raw["c"]["f"] == "${a}"

    def check_interpolated(cfg):
        # Shared expectations for every way of arriving at interpolated values
        assert cfg.is_interpolated
        assert cfg["c"]["d"] == 1
        assert cfg["c"]["e"] == "hello1"
        assert cfg["c"]["f"] == {"b": 1}

    # Via a serialization roundtrip, then via interpolate() on the raw config
    check_interpolated(Config().from_str(raw.to_str(), interpolate=True))
    check_interpolated(raw.interpolate())

    # Bad non-serializable value
    unserializable = {
        "x": {"y": numpy.asarray([[1, 2], [4, 5]], dtype="f"), "z": f"${{x{d}y}}"}
    }
    with pytest.raises(ConfigValidationError):
        Config(unserializable).interpolate()
def test_config_no_interpolation_registry():
    """Check that the registry can fill and resolve configs whether or not
    their variables have been interpolated yet."""
    source = """[a]\nbad = true\n[b]\n@cats = "catsie.v1"\nevil = ${a:bad}\n\n[c]\n d = ${b}"""

    # Non-interpolated config: variables are kept as written
    raw = Config().from_str(source, interpolate=False)
    assert not raw.is_interpolated
    assert raw["b"]["evil"] == "${a:bad}"
    assert raw["c"]["d"] == "${b}"
    raw_filled = my_registry.fill(raw)
    raw_resolved = my_registry.resolve(raw)
    assert raw_resolved["b"] == "scratch!"
    assert raw_resolved["c"]["d"] == "scratch!"
    # Filling adds the default for "cute" but leaves the variables alone
    assert raw_filled["b"]["evil"] == "${a:bad}"
    assert raw_filled["b"]["cute"] is True
    assert raw_filled["c"]["d"] == "${b}"
    interpolated = raw_filled.interpolate()
    assert interpolated.is_interpolated
    assert interpolated["b"]["evil"] is True
    assert interpolated["c"]["d"] == interpolated["b"]

    # Interpolated config: same resolved result, filled values are concrete
    cooked = Config().from_str(source, interpolate=True)
    assert cooked.is_interpolated
    cooked_filled = my_registry.fill(cooked)
    cooked_resolved = my_registry.resolve(cooked)
    assert cooked_resolved["b"] == "scratch!"
    assert cooked_resolved["c"]["d"] == "scratch!"
    assert cooked_filled["b"]["evil"] is True
    assert cooked_filled["c"]["d"] == cooked_filled["b"]

    # Resolving a non-interpolated filled config
    fresh = Config().from_str(source, interpolate=False)
    assert not fresh.is_interpolated
    fresh_filled = my_registry.fill(fresh)
    assert not fresh_filled.is_interpolated
    assert fresh_filled["c"]["d"] == "${b}"
    fresh_resolved = my_registry.resolve(fresh_filled)
    assert fresh_resolved["c"]["d"] == "scratch!"
def test_config_deep_merge():
    """Check ``Config(defaults).merge(config)`` on nested dicts, with and
    without "@"-prefixed factory keys in the merged blocks."""
    # Each scenario is (config, defaults, expected merged content)
    scenarios = [
        # Plain nested dicts: config wins, missing keys come from defaults
        (
            {"a": "hello", "b": {"c": "d"}},
            {"a": "world", "b": {"c": "e", "f": "g"}},
            {"a": "hello", "b": {"c": "d", "f": "g"}},
        ),
        # Same factory on both sides: block arguments are merged
        (
            {"a": "hello", "b": {"@test": "x", "foo": 1}},
            {"a": "world", "b": {"@test": "x", "foo": 100, "bar": 2}, "c": 100},
            {"a": "hello", "b": {"@test": "x", "foo": 1, "bar": 2}, "c": 100},
        ),
        # Different value for the same "@test" key: default arguments dropped
        (
            {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100},
            {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}},
            {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100},
        ),
        # Test that leaving out the factory just adds to existing
        (
            {"a": "hello", "b": {"foo": 1}, "c": 100},
            {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}},
            {"a": "hello", "b": {"@test": "y", "foo": 1, "bar": 2}, "c": 100},
        ),
        # Test that switching to a different factory prevents the default from
        # being added
        (
            {"a": "hello", "b": {"@foo": 1}, "c": 100},
            {"a": "world", "b": {"@bar": "y"}},
            {"a": "hello", "b": {"@foo": 1}, "c": 100},
        ),
        # Default is a plain value where the config has a block
        (
            {"a": "hello", "b": {"@foo": 1}, "c": 100},
            {"a": "world", "b": "y"},
            {"a": "hello", "b": {"@foo": 1}, "c": 100},
        ),
    ]
    for config, defaults, expected in scenarios:
        merged = Config(defaults).merge(config)
        assert len(merged) == len(expected)
        for key, value in expected.items():
            assert merged[key] == value
def test_config_deep_merge_variables():
    """Check merging when one side still contains uninterpolated variables."""
    # Variable in the config being merged in: kept as-is until interpolation
    incoming = Config().from_str(
        """[a]\nb= 1\nc = 2\n\n[d]\ne = ${a:b}""", interpolate=False
    )
    base = Config().from_str("""[a]\nx = 100\n\n[d]\ny = 500""")
    combined = base.merge(incoming)
    assert combined["a"] == {"b": 1, "c": 2, "x": 100}
    assert combined["d"] == {"e": "${a:b}", "y": 500}
    assert combined.interpolate()["d"] == {"e": 1, "y": 500}

    # With variable in defaults: overwritten by new value
    incoming = Config().from_str("""[a]\nb= 1\nc = 2""")
    base = Config().from_str("""[a]\nb = 100\nc = ${a:b}""", interpolate=False)
    combined = base.merge(incoming)
    assert combined["a"]["c"] == 2
def test_config_to_str_roundtrip():
    """Check that ``to_str`` output can be loaded back via ``from_str``."""
    numpy = pytest.importorskip("numpy")

    # A boolean and the string "false" must serialize differently and both
    # must come back as the value that went in
    roundtrips = [
        ({"cfg": {"foo": False}}, "[cfg]\nfoo = false"),
        ({"cfg": {"foo": "false"}}, '[cfg]\nfoo = "false"'),
    ]
    for data, serialized in roundtrips:
        dumped = Config(data).to_str()
        assert dumped == serialized
        reloaded = Config().from_str(dumped)
        assert dict(reloaded) == data

    # Bad non-serializable value
    unserializable = Config(
        {"cfg": {"x": numpy.asarray([[1, 2, 3, 4], [4, 5, 3, 4]], dtype="f")}}
    )
    with pytest.raises(ConfigValidationError):
        unserializable.to_str()

    # Roundtrip with variables: preserve variables correctly (quoted/unquoted)
    with_vars = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = "${a:b}\""""
    loaded = Config().from_str(with_vars, interpolate=False)
    assert loaded.to_str() == with_vars
def test_config_is_interpolated():
    """Check the ``is_interpolated`` flag after loading, merging, copying via
    the constructor and interpolating."""
    source = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = ${a}"""

    raw = Config().from_str(source, interpolate=False)
    assert not raw.is_interpolated

    # Merging in another config keeps the flag unset
    merged = raw.merge(Config({"x": {"y": "z"}}))
    assert not merged.is_interpolated

    # So does wrapping the config in a new Config
    rewrapped = Config(merged)
    assert not rewrapped.is_interpolated

    interpolated = rewrapped.interpolate()
    assert interpolated.is_interpolated

    # Merging a non-interpolated config into an interpolated one resets it
    remerged = interpolated.merge(Config().from_str(source, interpolate=False))
    assert not remerged.is_interpolated
@pytest.mark.parametrize(
    "section_order,expected_str,expected_keys",
    [
        # fmt: off
        ([], "[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[h]\ni = 5\n\n[j]\nk = 6", ["a", "h", "j"]),
        (["j", "h", "a"], "[j]\nk = 6\n\n[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4", ["j", "h", "a"]),
        (["h"], "[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[j]\nk = 6", ["h", "a", "j"])
        # fmt: on
    ],
)
def test_config_serialize_custom_sort(section_order, expected_str, expected_keys):
    """Check that ``section_order`` drives both the serialized section order
    and the order of the top-level keys."""
    # Deliberately not in the order that any of the parametrized cases expect
    data = {
        "j": {"k": 6},
        "a": {"b": 1, "d": {"e": 3}, "c": 2, "f": {"g": 4}},
        "h": {"i": 5},
    }
    default_dump = Config(data).to_str()

    # Serializing a dict-backed config honours the requested order
    ordered = Config(data, section_order=section_order)
    assert ordered.to_str() == expected_str

    # Keys follow the order when the config is parsed from a string ...
    parsed = Config(section_order=section_order).from_str(default_dump)
    assert list(parsed.keys()) == expected_keys

    # ... and when it is built directly from the dict
    built = Config(data, section_order=section_order)
    assert list(built.keys()) == expected_keys
def test_config_custom_sort_preserve():
    """Check that a custom section order is kept when a config is copied,
    merged, re-wrapped, or filled by the registry."""
    serialized = "[y]\n\n[z]\n\n[x]"
    original = Config({"x": {}, "y": {}, "z": {}}, section_order=["y", "z", "x"])
    assert original.to_str() == serialized

    # copy()
    assert original.copy().to_str() == serialized

    # merge(): the new section is appended after the ordered ones
    with_extra = original.merge({"a": {}})
    assert with_extra.to_str() == f"{serialized}\n\n[a]"

    # Config(config)
    assert Config(original).to_str() == serialized

    # Parsing from a string and filling through the registry
    registry_order = ["c", "a", "t"]
    registry_str = """[a]\nb = 1\n[c]\n@cats = "catsie.v1"\nevil = true\n\n[t]\n x = 2"""
    parsed = Config(section_order=registry_order).from_str(registry_str)
    assert list(parsed.keys()) == registry_order
    filled = my_registry.fill(parsed)
    assert filled.section_order == registry_order
def test_config_pickle():
    """Check that a pickle roundtrip keeps both the data and ``section_order``."""
    order = ["foo", "bar", "baz"]
    original = Config({"foo": "bar"}, section_order=order)
    restored = pickle.loads(pickle.dumps(original))
    assert restored == {"foo": "bar"}
    assert restored.section_order == ["foo", "bar", "baz"]
def test_config_fill_extra_fields():
    """Check what happens to fields unknown to the schema when a config is
    filled: rejected when validating, dropped or kept otherwise depending on
    the schema's ``extra`` setting."""

    class TestSchemaContent(BaseModel):
        a: str
        b: int

        class Config:
            extra = "forbid"

    class TestSchema(BaseModel):
        cfg: TestSchemaContent

    def fill_unvalidated(cfg, schema):
        # Fill without validation and return only the "cfg" block
        return my_registry.fill(cfg, schema=schema, validate=False)["cfg"]

    with_extra = Config({"cfg": {"a": "1", "b": 2, "c": True}})

    # With validation on, the extra field "c" is an error
    with pytest.raises(ConfigValidationError):
        my_registry.fill(with_extra, schema=TestSchema)

    # Without validation it is removed instead
    assert fill_unvalidated(with_extra, TestSchema) == {"a": "1", "b": 2}

    # Same outcome for an explicitly interpolated config ...
    interpolated = with_extra.interpolate()
    assert fill_unvalidated(interpolated, TestSchema) == {"a": "1", "b": 2}

    # ... and for one flagged as not interpolated
    not_interpolated = Config(
        {"cfg": {"a": "1", "b": 2, "c": True}}, is_interpolated=False
    )
    assert fill_unvalidated(not_interpolated, TestSchema) == {"a": "1", "b": 2}

    class TestSchemaContent2(BaseModel):
        a: str
        b: int

        class Config:
            extra = "allow"

    class TestSchema2(BaseModel):
        cfg: TestSchemaContent2

    # A schema that allows extras keeps the additional field
    assert fill_unvalidated(with_extra, TestSchema2) == {"a": "1", "b": 2, "c": True}
def test_config_validation_error_custom():
    """Check the attributes of a ``ConfigValidationError`` and of a copy made
    with ``ConfigValidationError.from_error`` using overridden fields."""

    class Schema(BaseModel):
        hello: int
        world: int

    invalid = {"hello": 1, "world": "hi!"}
    with pytest.raises(ConfigValidationError) as exc_info:
        my_registry._fill(invalid, Schema)
    original = exc_info.value

    # Defaults of the error raised by the registry
    assert original.title == "Config validation error"
    assert original.desc is None
    assert not original.parent
    assert original.show_config is True

    # Exactly one problem is reported: "world" is not an integer
    assert len(original.errors) == 1
    first_error = original.errors[0]
    assert first_error["loc"] == ("world",)
    assert first_error["msg"] == "value is not a valid integer"
    assert first_error["type"] == "type_error.integer"
    assert original.error_types == {"type_error.integer"}

    # Create a new error with overrides
    custom_title = "Custom error"
    custom_desc = "Some error description here"
    derived = ConfigValidationError.from_error(
        original, title=custom_title, desc=custom_desc, show_config=False
    )
    assert derived.errors == original.errors
    assert derived.error_types == original.error_types
    assert derived.title == custom_title
    assert derived.desc == custom_desc
    assert derived.show_config is False
    assert original.text != derived.text
def test_config_parsing_error():
    """Check that a line without a key/value separator is reported as a
    ``ConfigValidationError``."""
    malformed = "[a]\nb c"
    with pytest.raises(ConfigValidationError):
        Config().from_str(malformed)
def test_config_fill_without_resolve():
    """Check that ``fill`` succeeds in cases where ``resolve`` would fail or
    where the referenced function is not registered."""

    class BaseSchema(BaseModel):
        catsie: int

    cat_config = {"catsie": {"@cats": "catsie.v1", "evil": False}}
    filled = my_registry.fill(cat_config)
    resolved = my_registry.resolve(cat_config)
    assert resolved["catsie"] == "meow"
    assert filled["catsie"]["cute"] is True

    # Resolving against a schema that wants an int for "catsie" fails ...
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(cat_config, schema=BaseSchema)

    # ... while filling against the same schema works, and the filled config
    # can still be resolved afterwards
    filled_with_schema = my_registry.fill(cat_config, schema=BaseSchema)
    assert filled_with_schema["catsie"]["cute"] is True
    resolved = my_registry.resolve(filled_with_schema)
    assert resolved["catsie"] == "meow"

    # With unavailable function
    class BaseSchema2(BaseModel):
        catsie: Any
        other: int = 12

    unknown_config = {"catsie": {"@cats": "dog", "evil": False}}
    filled_unknown = my_registry.fill(unknown_config, schema=BaseSchema2)
    assert filled_unknown["catsie"] == unknown_config["catsie"]
    assert filled_unknown["other"] == 12
def test_config_dataclasses():
    """Check that a ``Cat`` instance passed as an argument to "catsie.v3"
    comes back from ``resolve`` as a ``Cat`` with the same field values."""
    original = Cat("testcat", value_in=1, value_out=2)
    block = {"@cats": "catsie.v3", "arg": original}
    resolved = my_registry.resolve({"cfg": block})["cfg"]
    assert isinstance(resolved, Cat)
    for field in ("name", "value_in", "value_out"):
        assert getattr(resolved, field) == getattr(original, field)
@pytest.mark.parametrize(
    "greeting,value,expected",
    [
        # simple substitution should go fine
        [342, "${vars.a}", int],
        ["342", "${vars.a}", str],
        ["everyone", "${vars.a}", str],
    ],
)
def test_config_interpolates(greeting, value, expected):
    """Check that an overridden value keeps its Python type when it is
    interpolated into another section.

    greeting: value used to override ``vars.a``.
    value: raw text written for ``project.my_par`` in the config string.
    expected: exact type ``project.my_par`` must have after loading.
    """
    str_cfg = f"""
[project]
my_par = {value}
[vars]
a = "something"
"""
    overrides = {"vars.a": greeting}
    cfg = Config().from_str(str_cfg, overrides=overrides)
    # Exact type match is intended (not isinstance), so compare by identity
    assert type(cfg["project"]["my_par"]) is expected
@pytest.mark.parametrize(
    "greeting,value,expected",
    [
        # fmt: off
        # simple substitution should go fine
        ["hello 342", "${vars.a}", "hello 342"],
        ["hello everyone", "${vars.a}", "hello everyone"],
        ["hello tout le monde", "${vars.a}", "hello tout le monde"],
        ["hello 42", "${vars.a}", "hello 42"],
        # substituting an element in a list
        ["hello 342", "[1, ${vars.a}, 3]", "hello 342"],
        ["hello everyone", "[1, ${vars.a}, 3]", "hello everyone"],
        ["hello tout le monde", "[1, ${vars.a}, 3]", "hello tout le monde"],
        ["hello 42", "[1, ${vars.a}, 3]", "hello 42"],
        # substituting part of a string
        [342, "hello ${vars.a}", "hello 342"],
        ["everyone", "hello ${vars.a}", "hello everyone"],
        ["tout le monde", "hello ${vars.a}", "hello tout le monde"],
        pytest.param("42", "hello ${vars.a}", "hello 42", marks=pytest.mark.xfail),
        # substituting part of a implicit string inside a list
        [342, "[1, hello ${vars.a}, 3]", "hello 342"],
        ["everyone", "[1, hello ${vars.a}, 3]", "hello everyone"],
        ["tout le monde", "[1, hello ${vars.a}, 3]", "hello tout le monde"],
        pytest.param("42", "[1, hello ${vars.a}, 3]", "hello 42", marks=pytest.mark.xfail),
        # substituting part of a explicit string inside a list
        [342, "[1, 'hello ${vars.a}', '3']", "hello 342"],
        ["everyone", "[1, 'hello ${vars.a}', '3']", "hello everyone"],
        ["tout le monde", "[1, 'hello ${vars.a}', '3']", "hello tout le monde"],
        pytest.param("42", "[1, 'hello ${vars.a}', '3']", "hello 42", marks=pytest.mark.xfail),
        # more complicated example
        [342, "[{'name':'x','script':['hello ${vars.a}']}]", "hello 342"],
        ["everyone", "[{'name':'x','script':['hello ${vars.a}']}]", "hello everyone"],
        ["tout le monde", "[{'name':'x','script':['hello ${vars.a}']}]", "hello tout le monde"],
        pytest.param("42", "[{'name':'x','script':['hello ${vars.a}']}]", "hello 42", marks=pytest.mark.xfail),
        # fmt: on
    ],
)
def test_config_overrides(greeting, value, expected):
    """Check that an override of ``vars.a`` shows up wherever ``${vars.a}`` is
    referenced in ``project.commands``.

    greeting: value used to override ``vars.a``.
    value: raw text written for ``project.commands`` in the config string.
    expected: substring that must appear in ``str()`` of the loaded config.
    """
    str_cfg = f"""
[project]
commands = {value}
[vars]
a = "world"
"""
    # Sanity check: the reference really made it into the config text
    assert "${vars.a}" in str_cfg
    loaded = Config().from_str(str_cfg, overrides={"vars.a": greeting})
    rendered = str(loaded)
    assert expected in rendered
def test_warn_single_quotes():
    """Check that a value wrapped in single quotes triggers the
    "single-quoted" ``UserWarning`` and that a quote in the middle of a value
    does not."""
    import warnings

    str_cfg = """
[project]
commands = 'do stuff'
"""
    with pytest.warns(UserWarning, match="single-quoted"):
        Config().from_str(str_cfg)

    # should not warn if single quotes are in the middle
    str_cfg = """
[project]
commands = some'thing
"""
    # Record every warning so the absence of the single-quote warning is
    # actually asserted; unrelated warnings from dependencies are ignored.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Config().from_str(str_cfg)
    assert not any("single-quoted" in str(w.message) for w in caught)
def test_parse_strings_interpretable_as_ints():
    """Check that a quoted string which looks like an integer after
    interpolation ("003") stays a string, while the bare reference yields
    the integer."""
    source = f"""[a]\nfoo = [${{b.bar}}, "00${{b.bar}}", "y"]\n\n[b]\nbar = 3"""  # noqa: F541
    parsed = Config().from_str(source)
    assert parsed["a"]["foo"] == [3, "003", "y"]
    assert parsed["b"]["bar"] == 3