Skip to content

Commit cd1d480

Browse files
TroyGarden and facebook-github-bot
authored and committed
An alternative implementation of torchrec serializer (#2166)
Summary: Pull Request resolved: #2166 # context * after discussion with dstaay-fb, a rule of thumb: serializer APIs are better kept symmetric * `SerializerInterface.serialize` takes a target module (nn.Module, usually sparse) and returns a tensor (the serialized binary, which will be put into a buffer) and a list of child_fqns, which require further serialization * `SerializerInterface.deserialize` takes the binary data (tensor from the buffer), a device flag, and the unflattened module (for its child modules), and returns the deserialized module * the main APIs for external usage are `serialize_embedding_modules` and `_deserialize_embedding_modules` * the former walks through the input module, finds the sparse (sub)modules, and stores the sparse modules' metadata in a buffer * the latter walks through the input module (unflattened from ep), finds the stored metadata in the buffer, and restores the sparse modules. Differential Revision: D58933792
1 parent 03c3a72 commit cd1d480

File tree

4 files changed

+144
-104
lines changed

4 files changed

+144
-104
lines changed

torchrec/ir/serializer.py

Lines changed: 48 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import json
1111
import logging
12-
from typing import Any, Dict, Optional, Type
12+
from typing import Any, Dict, List, Optional, Tuple, Type
1313

1414
import torch
1515

@@ -69,8 +69,18 @@ def get_deserialized_device(
6969
return device
7070

7171

72-
class JsonSerializerBase(SerializerInterface):
72+
class JsonSerializer(SerializerInterface):
73+
"""
74+
Serializer for torch.export IR using json.
75+
"""
76+
77+
module_to_serializer_cls: Dict[str, Type["JsonSerializer"]] = {}
7378
_module_cls: Optional[Type[nn.Module]] = None
79+
_children: Optional[List[str]] = None
80+
81+
@classmethod
82+
def children(cls, module: nn.Module) -> List[str]:
83+
return [] if not cls._children else cls._children
7484

7585
@classmethod
7686
def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]:
@@ -81,47 +91,67 @@ def deserialize_from_dict(
8191
cls,
8292
metadata_dict: Dict[str, Any],
8393
device: Optional[torch.device] = None,
94+
unflatten: Optional[nn.Module] = None,
8495
) -> nn.Module:
8596
raise NotImplementedError()
8697

8798
@classmethod
8899
def serialize(
89100
cls,
90101
module: nn.Module,
91-
) -> torch.Tensor:
92-
if cls._module_cls is None:
102+
) -> Tuple[torch.Tensor, List[str]]:
103+
typename = type(module).__name__
104+
serializer = cls.module_to_serializer_cls.get(typename)
105+
if serializer is None:
93106
raise ValueError(
94-
"Must assign a nn.Module to class static variable _module_cls"
107+
f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
95108
)
96-
if not isinstance(module, cls._module_cls):
109+
assert issubclass(serializer, JsonSerializer)
110+
assert serializer._module_cls is not None
111+
if not isinstance(module, serializer._module_cls):
97112
raise ValueError(
98-
f"Expected module to be of type {cls._module_cls.__name__}, got {type(module)}"
113+
f"Expected module to be of type {serializer._module_cls.__name__}, "
114+
f"got {type(module)}"
99115
)
100-
metadata_dict = cls.serialize_to_dict(module)
101-
return torch.frombuffer(json.dumps(metadata_dict).encode(), dtype=torch.uint8)
116+
metadata_dict = serializer.serialize_to_dict(module)
117+
raw_dict = {"typename": typename, "metadata_dict": metadata_dict}
118+
serialized_tensor = torch.frombuffer(
119+
json.dumps(raw_dict).encode(), dtype=torch.uint8
120+
)
121+
return serialized_tensor, serializer.children(module)
102122

103123
@classmethod
104124
def deserialize(
105125
cls,
106126
input: torch.Tensor,
107-
typename: str,
108127
device: Optional[torch.device] = None,
128+
unflatten: Optional[nn.Module] = None,
109129
) -> nn.Module:
110130
raw_bytes = input.numpy().tobytes()
111-
metadata_dict = json.loads(raw_bytes.decode())
112-
module = cls.deserialize_from_dict(metadata_dict, device)
113-
if cls._module_cls is None:
131+
raw_dict = json.loads(raw_bytes.decode())
132+
typename = raw_dict["typename"]
133+
if typename not in cls.module_to_serializer_cls:
134+
raise ValueError(
135+
f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
136+
)
137+
serializer = cls.module_to_serializer_cls[typename]
138+
assert issubclass(serializer, JsonSerializer)
139+
module = serializer.deserialize_from_dict(
140+
raw_dict["metadata_dict"], device, unflatten
141+
)
142+
143+
if serializer._module_cls is None:
114144
raise ValueError(
115145
"Must assign a nn.Module to class static variable _module_cls"
116146
)
117-
if not isinstance(module, cls._module_cls):
147+
if not isinstance(module, serializer._module_cls):
118148
raise ValueError(
119-
f"Expected module to be of type {cls._module_cls.__name__}, got {type(module)}"
149+
f"Expected module to be of type {serializer._module_cls.__name__}, got {type(module)}"
120150
)
121151
return module
122152

123153

124-
class EBCJsonSerializer(JsonSerializerBase):
154+
class EBCJsonSerializer(JsonSerializer):
125155
_module_cls = EmbeddingBagCollection
126156

127157
@classmethod
@@ -148,6 +178,7 @@ def deserialize_from_dict(
148178
cls,
149179
metadata_dict: Dict[str, Any],
150180
device: Optional[torch.device] = None,
181+
unflatten: Optional[nn.Module] = None,
151182
) -> nn.Module:
152183
tables = [
153184
EmbeddingBagConfigMetadata(**table_config)
@@ -164,40 +195,4 @@ def deserialize_from_dict(
164195
)
165196

166197

167-
class JsonSerializer(SerializerInterface):
168-
"""
169-
Serializer for torch.export IR using json.
170-
"""
171-
172-
module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = {
173-
"EmbeddingBagCollection": EBCJsonSerializer,
174-
}
175-
176-
@classmethod
177-
def serialize(
178-
cls,
179-
module: nn.Module,
180-
) -> torch.Tensor:
181-
typename = type(module).__name__
182-
if typename not in cls.module_to_serializer_cls:
183-
raise ValueError(
184-
f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
185-
)
186-
187-
return cls.module_to_serializer_cls[typename].serialize(module)
188-
189-
@classmethod
190-
def deserialize(
191-
cls,
192-
input: torch.Tensor,
193-
typename: str,
194-
device: Optional[torch.device] = None,
195-
) -> nn.Module:
196-
if typename not in cls.module_to_serializer_cls:
197-
raise ValueError(
198-
f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
199-
)
200-
201-
return cls.module_to_serializer_cls[typename].deserialize(
202-
input, typename, device
203-
)
198+
JsonSerializer.module_to_serializer_cls["EmbeddingBagCollection"] = EBCJsonSerializer

torchrec/ir/tests/test_serializer.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import copy
1313
import unittest
14-
from typing import Callable, List, Optional, Union
14+
from typing import Any, Callable, Dict, List, Optional, Union
1515

1616
import torch
1717
from torch import nn
@@ -54,6 +54,41 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
5454
return res
5555

5656

57+
class CompoundModuleSerializer(JsonSerializer):
58+
_module_cls = CompoundModule
59+
60+
@classmethod
61+
def children(cls, module: nn.Module) -> List[str]:
62+
children = ["ebc", "list"]
63+
if module.comp is not None:
64+
children += ["comp"]
65+
return children
66+
67+
@classmethod
68+
def serialize_to_dict(
69+
cls,
70+
module: nn.Module,
71+
) -> Dict[str, Any]:
72+
return {}
73+
74+
@classmethod
75+
def deserialize_from_dict(
76+
cls,
77+
metadata_dict: Dict[str, Any],
78+
device: Optional[torch.device] = None,
79+
unflatten: Optional[nn.Module] = None,
80+
) -> nn.Module:
81+
assert unflatten is not None
82+
ebc = unflatten.ebc
83+
comp = getattr(unflatten, "comp", None)
84+
i = 0
85+
mlist = []
86+
while hasattr(unflatten.list, str(i)):
87+
mlist.append(getattr(unflatten.list, str(i)))
88+
i += 1
89+
return CompoundModule(ebc, comp, mlist)
90+
91+
5792
class TestJsonSerializer(unittest.TestCase):
5893
def generate_model(self) -> nn.Module:
5994
class Model(nn.Module):
@@ -328,6 +363,9 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
328363

329364
eager_out = model(id_list_features)
330365

366+
JsonSerializer.module_to_serializer_cls["CompoundModule"] = (
367+
CompoundModuleSerializer
368+
)
331369
# Serialize
332370
model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
333371
ep = torch.export.export(
@@ -346,6 +384,14 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
346384

347385
# Deserialize
348386
deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
387+
388+
# Check if Compound Module is deserialized correctly
389+
self.assertIsInstance(deserialized_model.comp, CompoundModule)
390+
self.assertIsInstance(deserialized_model.comp.comp, CompoundModule)
391+
self.assertIsInstance(deserialized_model.comp.comp.comp, CompoundModule)
392+
self.assertIsInstance(deserialized_model.comp.list[1], CompoundModule)
393+
self.assertIsInstance(deserialized_model.comp.list[1].comp, CompoundModule)
394+
349395
deserialized_model.load_state_dict(model.state_dict())
350396
# Run forward on deserialized model
351397
deserialized_out = deserialized_model(id_list_features)

torchrec/ir/types.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#!/usr/bin/env python3
1111

1212
import abc
13-
from typing import Any, Dict, Optional, Type
13+
from typing import Any, Dict, List, Optional, Tuple
1414

1515
import torch
1616

@@ -24,28 +24,25 @@ class SerializerInterface(abc.ABC):
2424

2525
@classmethod
2626
@property
27-
# pyre-ignore [3]: Returning `None` but type `Any` is specified.
28-
def module_to_serializer_cls(cls) -> Dict[str, Type[Any]]:
27+
def module_to_serializer_cls(cls) -> Dict[str, Any]:
2928
raise NotImplementedError
3029

3130
@classmethod
3231
@abc.abstractmethod
33-
# pyre-ignore [3]: Returning `None` but type `Any` is specified.
3432
def serialize(
3533
cls,
3634
module: nn.Module,
37-
) -> Any:
35+
) -> Tuple[torch.Tensor, List[str]]:
3836
# Take the eager embedding module and generate bytes in buffer
39-
pass
37+
raise NotImplementedError
4038

4139
@classmethod
4240
@abc.abstractmethod
4341
def deserialize(
4442
cls,
45-
# pyre-ignore [2]: Parameter `input` must have a type other than `Any`.
46-
input: Any,
47-
typename: str,
43+
input: torch.Tensor,
4844
device: Optional[torch.device] = None,
45+
unflatten: Optional[nn.Module] = None,
4946
) -> nn.Module:
5047
# Take the bytes in the buffer and regenerate the eager embedding module
51-
pass
48+
raise NotImplementedError

torchrec/ir/utils.py

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@
2727

2828

2929
def serialize_embedding_modules(
30-
model: nn.Module,
30+
module: nn.Module,
3131
serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
32+
fqn: str = "",
3233
) -> Tuple[nn.Module, List[str]]:
3334
"""
3435
Takes all the modules that are of type `serializer_cls` and serializes them
@@ -37,13 +38,46 @@ def serialize_embedding_modules(
3738
Returns the modified module and the list of fqns that had the buffer added.
3839
"""
3940
preserve_fqns = []
40-
for fqn, module in model.named_modules():
41-
if type(module).__name__ in serializer_cls.module_to_serializer_cls:
42-
serialized_module = serializer_cls.serialize(module)
43-
module.register_buffer("ir_metadata", serialized_module, persistent=False)
44-
preserve_fqns.append(fqn)
4541

46-
return model, preserve_fqns
42+
# handle current module
43+
if type(module).__name__ in serializer_cls.module_to_serializer_cls:
44+
serialized_tensor, children = serializer_cls.serialize(module)
45+
module.register_buffer("ir_metadata", serialized_tensor, persistent=False)
46+
preserve_fqns.append(fqn)
47+
else:
48+
children = [child for child, _ in module.named_children()]
49+
50+
# handle child modules
51+
for child in children:
52+
submodule = module.get_submodule(child)
53+
child_fqn = f"{fqn}.{child}" if len(fqn) > 0 else child
54+
preserve_fqns.extend(
55+
serialize_embedding_modules(submodule, serializer_cls, child_fqn)[1]
56+
)
57+
return module, preserve_fqns
58+
59+
60+
def _deserialize_embedding_modules(
61+
module: nn.Module,
62+
serializer_cls: Type[SerializerInterface],
63+
device: Optional[torch.device] = None,
64+
) -> nn.Module:
65+
"""
66+
returns:
67+
1. the children of the parent_fqn Dict[relative_fqn -> module]
68+
2. the next node Optional[fqn], Optional[module], which is not a child of the parent_fqn
69+
"""
70+
71+
for child_fqn, child in module.named_children():
72+
child = _deserialize_embedding_modules(
73+
module=child, serializer_cls=serializer_cls, device=device
74+
)
75+
setattr(module, child_fqn, child)
76+
77+
if "ir_metadata" in dict(module.named_buffers()):
78+
serialized_tensor = module.get_buffer("ir_metadata")
79+
module = serializer_cls.deserialize(serialized_tensor, device, module)
80+
return module
4781

4882

4983
def deserialize_embedding_modules(
@@ -59,39 +93,7 @@ def deserialize_embedding_modules(
5993
Returns the unflattened ExportedProgram with the deserialized modules.
6094
"""
6195
model = torch.export.unflatten(ep)
62-
module_type_dict = {}
63-
for node in ep.graph.nodes:
64-
if "nn_module_stack" in node.meta:
65-
for fqn, type_name in node.meta["nn_module_stack"].values():
66-
# Only get the module type name, not the full type name
67-
module_type_dict[fqn] = type_name.split(".")[-1]
68-
69-
fqn_to_new_module = {}
70-
for fqn, module in model.named_modules():
71-
if "ir_metadata" in dict(module.named_buffers()):
72-
serialized_module = dict(module.named_buffers())["ir_metadata"]
73-
74-
if fqn not in module_type_dict:
75-
raise RuntimeError(
76-
f"Cannot find the type of module {fqn} in the exported program"
77-
)
78-
79-
deserialized_module = serializer_cls.deserialize(
80-
serialized_module,
81-
module_type_dict[fqn],
82-
device,
83-
)
84-
fqn_to_new_module[fqn] = deserialized_module
85-
86-
for fqn, new_module in fqn_to_new_module.items():
87-
# handle nested attribute like "x.y.z"
88-
attrs = fqn.split(".")
89-
parent = model
90-
for a in attrs[:-1]:
91-
parent = getattr(parent, a)
92-
setattr(parent, attrs[-1], new_module)
93-
94-
return model
96+
return _deserialize_embedding_modules(model, serializer_cls, device)
9597

9698

9799
def _get_dim(x: Union[DIM, str, None], s: str, max: Optional[int] = None) -> DIM:

0 commit comments

Comments
 (0)