Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 09b9170

Browse files
mannatsinghfacebook-github-bot
authored andcommitted
Implement config validation to find unused keys
Summary: Implement a `ClassyMap` type which supports tracking reads and freezing the map (the latter is unused currently). Added it to `ClassificationTask` to catch cases where we don't use any keys passed by users. This will not catch all instances, like when some components do a deepcopy - we assume all the keys and sub-keys are read in such a situation Differential Revision: D25321360 fbshipit-source-id: aff06f1b3334ca9d217453d590ac413b4f586966
1 parent bd5c260 commit 09b9170

File tree

6 files changed

+313
-5
lines changed

6 files changed

+313
-5
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from .classy_map import ClassyMap
7+
from .config_error import ConfigError, ConfigUnusedKeysError
8+
9+
__all__ = ["ClassyMap", "ConfigError", "ConfigUnusedKeysError"]
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import copy
7+
from collections.abc import MutableMapping, Mapping
8+
9+
10+
class ClassyMap(MutableMapping):
11+
"""Mapping which can be made immutable. Also supports tracking unused keys."""
12+
13+
def __init__(self, *args, **kwargs):
14+
"""Create a ClassyMap.
15+
16+
Supports the same API as a dict and recursively converts all dicts to
17+
ClassyMaps.
18+
"""
19+
20+
# NOTE: Another way to implement this would be to subclass dict, but since dict
21+
# is a built-in, it isn't treated like a regular MutableMapping, and calls like
22+
# func(**map) are handled mysteriously, probably interpreter dependent.
23+
# The downside with this implementation is that this isn't a full dict and is
24+
# just a mapping, which means some features like JSON serialization don't work
25+
26+
self._dict = dict(*args, **kwargs)
27+
self._frozen = False
28+
self._keys_read = set()
29+
for k, v in self._dict.items():
30+
self._dict[k] = self.to_classy_dict(v)
31+
32+
def to_classy_dict(self, obj):
33+
"""Recursively convert all sub items inside obj to ClassyMaps"""
34+
35+
if isinstance(obj, Mapping):
36+
obj = ClassyMap({k: self.to_classy_dict(v) for k, v in obj.items()})
37+
elif isinstance(obj, (list, tuple)):
38+
# tuples are also converted to lists
39+
obj = [self.to_classy_dict(v) for v in obj]
40+
return obj
41+
42+
def keys(self):
43+
return self._dict.keys()
44+
45+
def items(self):
46+
self._keys_read.update(self._dict.keys())
47+
return self._dict.items()
48+
49+
def values(self):
50+
self._keys_read.update(self._dict.keys())
51+
return self._dict.values()
52+
53+
def pop(self, key, default=None):
54+
return self._dict.pop(key, default)
55+
56+
def popitem(self):
57+
return self._dict.popitem()
58+
59+
def clear(self):
60+
self._dict.clear()
61+
62+
def update(self, *args, **kwargs):
63+
if self._frozen:
64+
raise TypeError("Frozen ClassyMaps do not support updates")
65+
self._dict.update(*args, **kwargs)
66+
67+
def setdefault(self, key, default=None):
68+
return self._dict.setdefault(key, default)
69+
70+
def __contains__(self, key):
71+
return key in self._dict
72+
73+
def __eq__(self, obj):
74+
return self._dict == obj
75+
76+
def __len__(self):
77+
return len(self._dict)
78+
79+
def __getitem__(self, key):
80+
self._keys_read.add(key)
81+
return self._dict.__getitem__(key)
82+
83+
def __iter__(self):
84+
return iter(self._dict)
85+
86+
def __str__(self):
87+
return str(self._dict)
88+
89+
def __repr__(self):
90+
return repr(self._dict)
91+
92+
def get(self, key, default=None):
93+
if key in self._dict.keys():
94+
self._keys_read.add(key)
95+
return self._dict.get(key, default)
96+
97+
def __copy__(self):
98+
ret = ClassyMap()
99+
for key, value in self._dict.items():
100+
self._keys_read.add(key)
101+
ret._dict[key] = value
102+
103+
def copy(self):
104+
return self.__copy__()
105+
106+
def __deepcopy__(self, memo=None):
107+
# for deepcopies we mark all the keys and sub-keys as read
108+
ret = ClassyMap()
109+
for key, value in self._dict.items():
110+
self._keys_read.add(key)
111+
ret._dict[key] = copy.deepcopy(value)
112+
return ret
113+
114+
def __setitem__(self, key, value):
115+
if self._frozen:
116+
raise TypeError("Frozen ClassyMaps do not support assignment")
117+
if isinstance(value, dict) and not isinstance(value, ClassyMap):
118+
value = ClassyMap(value)
119+
self._dict.__setitem__(key, value)
120+
121+
def __delitem__(self, key):
122+
if self._frozen:
123+
raise TypeError("Frozen ClassyMaps do not support key deletion")
124+
del self._dict[key]
125+
126+
def _freeze(self, obj):
127+
if isinstance(obj, Mapping):
128+
assert isinstance(obj, ClassyMap), f"{obj} is not a ClassyMap"
129+
obj._frozen = True
130+
for value in obj.values():
131+
self._freeze(value)
132+
elif isinstance(obj, list):
133+
for value in obj:
134+
self._freeze(value)
135+
136+
def _reset_tracking(self, obj):
137+
if isinstance(obj, Mapping):
138+
assert isinstance(obj, ClassyMap), f"{obj} is not a ClassyMap"
139+
obj._keys_read = set()
140+
for value in obj._dict.values():
141+
self._reset_tracking(value)
142+
elif isinstance(obj, list):
143+
for value in obj:
144+
self._reset_tracking(value)
145+
146+
def _unused_keys(self, obj):
147+
unused_keys = []
148+
if isinstance(obj, Mapping):
149+
assert isinstance(obj, ClassyMap), f"{obj} is not a ClassyMap"
150+
unused_keys = [key for key in obj._dict.keys() if key not in obj._keys_read]
151+
for key, value in obj._dict.items():
152+
unused_keys += [
153+
f"{key}.{subkey}" for subkey in self._unused_keys(value)
154+
]
155+
elif isinstance(obj, list):
156+
for i, value in enumerate(obj):
157+
unused_keys += [f"{i}.{subkey}" for subkey in self._unused_keys(value)]
158+
return unused_keys
159+
160+
def freeze(self):
161+
"""Freeze the ClassyMap to disallow mutations"""
162+
self._freeze(self)
163+
164+
def reset_tracking(self):
165+
"""Reset key tracking"""
166+
self._reset_tracking(self)
167+
168+
def unused_keys(self):
169+
"""Fetch all the unused keys"""
170+
return self._unused_keys(self)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import List
7+
8+
9+
class ConfigError(Exception):
10+
pass
11+
12+
13+
class ConfigUnusedKeysError(ConfigError):
14+
def __init__(self, unused_keys: List[str]):
15+
self.unused_keys = unused_keys
16+
super().__init__(f"The following keys were unused: {self.unused_keys}")

classy_vision/optim/sgd.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Dict
88

99
import torch.optim
10+
from classy_vision.configuration import ClassyMap
1011

1112
from . import ClassyOptimizer, register_optimizer
1213

@@ -63,10 +64,11 @@ def from_config(cls, config: Dict[str, Any]) -> "SGD":
6364
config.setdefault("weight_decay", 0.0)
6465
config.setdefault("nesterov", False)
6566
config.setdefault("use_larc", False)
66-
config.setdefault(
67-
"larc_config", {"clip": True, "eps": 1e-08, "trust_coefficient": 0.02}
68-
)
69-
67+
if config["use_larc"]:
68+
larc_config = ClassyMap(clip=True, eps=1e-8, trust_coefficient=0.02)
69+
else:
70+
larc_config = None
71+
config.setdefault("larc_config", larc_config)
7072
assert (
7173
config["momentum"] >= 0.0
7274
and config["momentum"] < 1.0

classy_vision/tasks/classification_task.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import torch
1818
import torch.nn as nn
19+
from classy_vision.configuration import ClassyMap, ConfigUnusedKeysError
1920
from classy_vision.dataset import ClassyDataset, build_dataset
2021
from classy_vision.dataset.transforms.mixup import MixupTransform
2122
from classy_vision.generic.distributed_util import (
@@ -456,6 +457,11 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
456457
Returns:
457458
A ClassificationTask instance.
458459
"""
460+
orig_config = config
461+
config = ClassyMap(orig_config)
462+
# access the name key to make sure it gets tracked
463+
config["name"]
464+
459465
test_only = config.get("test_only", False)
460466
if not test_only:
461467
# TODO Make distinction between epochs and phases in optimizer clear
@@ -537,9 +543,13 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
537543
for phase_type in datasets:
538544
task.set_dataset(datasets[phase_type], phase_type)
539545

546+
# at this stage all the configs keys should have been used
547+
if config.unused_keys():
548+
raise ConfigUnusedKeysError(config.unused_keys())
549+
540550
# NOTE: this is a private member and only meant to be used for
541551
# logging/debugging purposes. See __repr__ implementation
542-
task._config = config
552+
task._config = orig_config
543553

544554
return task
545555

test/configuration_classy_map_test.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import copy
7+
import unittest
8+
9+
from classy_vision.configuration import ClassyMap
10+
11+
12+
class ClassyMapTest(unittest.TestCase):
13+
def test_dict(self):
14+
d = ClassyMap(a=1, b=[1, 2, "3"])
15+
d["c"] = [4]
16+
d["d"] = {"a": 2}
17+
self.assertEqual(d, {"a": 1, "b": [1, 2, "3"], "c": [4], "d": {"a": 2}})
18+
self.assertIsInstance(d, ClassyMap)
19+
self.assertIsInstance(d["d"], ClassyMap)
20+
21+
def test_freezing(self):
22+
d = ClassyMap(a=1, b=2)
23+
d.freeze()
24+
# resetting an already existing key
25+
with self.assertRaises(TypeError):
26+
d["a"] = 3
27+
# adding a new key
28+
with self.assertRaises(TypeError):
29+
d["f"] = 3
30+
31+
def test_unused_keys(self):
32+
d = ClassyMap(
33+
a=1,
34+
b=[
35+
1,
36+
2,
37+
{
38+
"c": {"a": 2},
39+
"d": 4,
40+
"e": {"a": 1, "b": 2},
41+
"f": {"a": 1, "b": {"c": 2}},
42+
},
43+
],
44+
)
45+
46+
all_keys = {
47+
"a",
48+
"b",
49+
"b.2.c",
50+
"b.2.c.a",
51+
"b.2.d",
52+
"b.2.e",
53+
"b.2.f",
54+
"b.2.e.a",
55+
"b.2.e.b",
56+
"b.2.f.a",
57+
"b.2.f.b",
58+
"b.2.f.b.c",
59+
}
60+
61+
def test_func(**kwargs):
62+
return None
63+
64+
for _ in range(2):
65+
expected_unused_keys = all_keys.copy()
66+
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)
67+
68+
_ = d["a"]
69+
expected_unused_keys.remove("a")
70+
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)
71+
72+
_ = d["b"][2].get("d")
73+
expected_unused_keys.remove("b")
74+
expected_unused_keys.remove("b.2.d")
75+
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)
76+
77+
_ = d["b"][2]["e"]
78+
expected_unused_keys.remove("b.2.e")
79+
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)
80+
81+
_ = d["b"][2]["e"].items()
82+
expected_unused_keys.remove("b.2.e.a")
83+
expected_unused_keys.remove("b.2.e.b")
84+
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)
85+
86+
_ = d["b"][2]["f"]
87+
expected_unused_keys.remove("b.2.f")
88+
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)
89+
90+
test_func(**d["b"][2]["f"])
91+
expected_unused_keys.remove("b.2.f.a")
92+
expected_unused_keys.remove("b.2.f.b")
93+
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)
94+
95+
_ = copy.deepcopy(d)
96+
expected_unused_keys.remove("b.2.c")
97+
expected_unused_keys.remove("b.2.c.a")
98+
expected_unused_keys.remove("b.2.f.b.c")
99+
self.assertSetEqual(set(d.unused_keys()), expected_unused_keys)
100+
101+
d.reset_tracking()

0 commit comments

Comments
 (0)