Skip to content

Commit e51c8b9

Browse files
authored
Add back support for stage.remap_qualname() (#934)
## Description

`stage.remap_qualname(key)` maps a stage's parameter name (`key`) back to the original model's parameter name.

This now works:

```
# Stage module's state dict
sd = stage.submod.state_dict()
remapped_keys = [stage.remap_qualname(k) for k in sd.keys()]

# Original model's state dict
old_keys = mod.state_dict().keys()

# Confirm they match
assert all(rk in old_keys for rk in remapped_keys)
```
1 parent 3632106 commit e51c8b9

File tree

5 files changed

+74
-71
lines changed

5 files changed

+74
-71
lines changed

pippy/IR.py

Lines changed: 6 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from pippy.backward import _null_coalesce_accumulate, stage_backward
3030
from pippy.debug import PIPPY_VERBOSITY
3131
from pippy.microbatch import LossReducer, split_args_kwargs_into_chunks
32+
from pippy.utils import QualnameMapMixin
3233

3334

3435
logger = logging.getLogger(__name__)
@@ -498,54 +499,6 @@ def _direct_serialization_reduce(self):
498499
)
499500

500501

501-
class QualnameMapMixin:
502-
"""
503-
A mixin class to provide qualname remap functionality for both Pipe object
504-
and submodules
505-
"""
506-
507-
def __init__(
508-
self,
509-
splitter_qualname_map: Dict[str, str] = None,
510-
tracer_qualname_map: Dict[str, str] = None,
511-
):
512-
self.new_to_old_qualname_mapping: Dict[str, str] = (
513-
splitter_qualname_map or {}
514-
)
515-
self.tracer_qualname_map = tracer_qualname_map
516-
517-
def remap_qualname(self, qualname: str):
518-
# TODO: annoying
519-
if qualname.startswith("split_gm."):
520-
qualname = qualname[len("split_gm.") :]
521-
522-
name_before_split = None
523-
if qualname in self.new_to_old_qualname_mapping:
524-
name_before_split = self.new_to_old_qualname_mapping[qualname]
525-
else:
526-
# The qualname map does not store recursive items, thus,
527-
# when passed a qualname with leaves, we need to perform longest prefix match
528-
# Split from the right, one each time
529-
split_names = qualname.rsplit(".", 1)
530-
leaf = split_names[-1]
531-
while len(split_names) > 1:
532-
prefix = split_names[0]
533-
if prefix in self.new_to_old_qualname_mapping:
534-
old_prefix = self.new_to_old_qualname_mapping[prefix]
535-
name_before_split = ".".join([old_prefix, leaf])
536-
break
537-
split_names = prefix.rsplit(".", 1)
538-
leaf = ".".join([split_names[-1], leaf])
539-
540-
if name_before_split is None:
541-
raise RuntimeError(f"Could not find mapping for {qualname}")
542-
543-
if self.tracer_qualname_map is not None:
544-
return self.tracer_qualname_map[name_before_split]
545-
else:
546-
return name_before_split
547-
548-
549502
class Pipe(QualnameMapMixin, torch.nn.Module):
550503
def __init__(
551504
self,
@@ -615,6 +568,10 @@ def __init__(
615568
)
616569

617570
# Create qualname mapping for each submodule
571+
# Dict looks like this:
572+
# {submod_name : Dict{new_qualname : old_qualname}}
573+
# We save this information here for use during pipeline stage creation.
574+
self.submod_qualname_mappings: Dict[str, Dict[str, str]] = {}
618575
for m_qualname, mod in self.split_gm.named_children():
619576
# "submod_x." prefix
620577
mod_prefix = m_qualname + "."
@@ -624,16 +581,7 @@ def __init__(
624581
# Remove prefix
625582
new_key = k[len(mod_prefix) :]
626583
mod_qualname_mapping.setdefault(new_key, v)
627-
# Add a remap mixin to submodule instance
628-
# TODO: this class change is commented out because it breaks
629-
# recompilation if we want to recompile mod after. For example, we
630-
# may recompile mod to modify the "device" kwarg of a `torch.ones`
631-
# node (trace on cpu/meta, run on cuda).
632-
# See: https://github.com/pytorch/vision/issues/5826
633-
# mod.__class__ = type(
634-
# "PipeStageModule", (QualnameMapMixin, mod.__class__), {}
635-
# )
636-
setattr(mod, "new_to_old_qualname_mapping", mod_qualname_mapping)
584+
self.submod_qualname_mappings[m_qualname] = mod_qualname_mapping
637585

638586
def throw(self, *args, **kwargs):
639587
raise RuntimeError(

pippy/PipelineStage.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pippy.debug import map_debug_info
1515
from pippy.IR import Pipe
1616
from pippy.microbatch import merge_chunks, split_args_kwargs_into_chunks
17-
from pippy.utils import flatten_args, modify_graph_op_device
17+
from pippy.utils import flatten_args, modify_graph_op_device, QualnameMapMixin
1818

1919

2020
logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class StageKwargPlaceholder:
5252
pass
5353

5454

55-
class PipelineStage(torch.nn.Module):
55+
class PipelineStage(torch.nn.Module, QualnameMapMixin):
5656
def __init__(
5757
self,
5858
pipe: Pipe,
@@ -98,6 +98,13 @@ def __init__(
9898
f"{self.submod}"
9999
)
100100

101+
# Enable `remap_qualname` method
102+
QualnameMapMixin.__init__(
103+
self,
104+
pipe.submod_qualname_mappings[self.name],
105+
pipe.tracer_qualname_map,
106+
)
107+
101108
# Find my forward node in graph
102109
found_node = False
103110
for node in self.split_gm.graph.nodes:

pippy/utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates
22
import logging
3+
from typing import Dict
34

45
import torch
56
import torch.distributed as dist
@@ -92,3 +93,51 @@ def modify_graph_op_device(
9293

9394
if modified:
9495
gm.recompile()
96+
97+
98+
class QualnameMapMixin:
    """
    A mixin class to provide qualname remap functionality for both Pipe object
    and submodules.

    Remapping is done in up to two steps:

    1. ``new_to_old_qualname_mapping`` translates a qualified name as it
       appears after splitting into the name it had before splitting.
    2. ``tracer_qualname_map``, when provided, translates that pre-split name
       once more.
    """

    def __init__(
        self,
        splitter_qualname_map: Dict[str, str] = None,
        tracer_qualname_map: Dict[str, str] = None,
    ):
        """
        Args:
            splitter_qualname_map: mapping from post-split qualnames to
                pre-split qualnames. ``None`` (or an empty mapping) results in
                a fresh empty dict being stored.
            tracer_qualname_map: optional second-stage mapping applied to the
                pre-split qualname. When ``None``, `remap_qualname` returns
                the pre-split name directly.
        """
        self.new_to_old_qualname_mapping: Dict[str, str] = (
            splitter_qualname_map or {}
        )
        self.tracer_qualname_map = tracer_qualname_map

    def _lookup_name_before_split(self, qualname: str):
        """
        Resolve `qualname` through ``new_to_old_qualname_mapping``.

        Returns the pre-split name, or ``None`` when neither `qualname` itself
        nor any of its dotted prefixes is present in the mapping.
        """
        mapping = self.new_to_old_qualname_mapping
        if qualname in mapping:
            return mapping[qualname]

        # The qualname map does not store recursive items, thus, when passed a
        # qualname with leaves, we need to perform longest prefix match.
        # Drop one trailing component at a time so the longest prefix is
        # tried first.
        prefix = qualname
        while "." in prefix:
            prefix = prefix.rsplit(".", 1)[0]
            if prefix in mapping:
                # Re-attach the untouched remainder (it starts with ".").
                return mapping[prefix] + qualname[len(prefix) :]
        return None

    def remap_qualname(self, qualname: str):
        """
        Map a post-split qualified name back to its original name.

        A leading ``"split_gm."`` is stripped before the lookup.

        Args:
            qualname: qualified name to remap.

        Returns:
            The remapped qualified name.

        Raises:
            RuntimeError: if no splitter mapping exists for `qualname` or any
                of its dotted prefixes.
            KeyError: if a tracer map was provided but has no entry for the
                pre-split name.
        """
        # TODO: annoying
        if qualname.startswith("split_gm."):
            qualname = qualname[len("split_gm.") :]

        name_before_split = self._lookup_name_before_split(qualname)
        if name_before_split is None:
            raise RuntimeError(f"Could not find mapping for {qualname}")

        if self.tracer_qualname_map is None:
            return name_before_split

        try:
            return self.tracer_qualname_map[name_before_split]
        except KeyError:
            # Keep the exception type, but say which name failed and where it
            # came from instead of surfacing a bare key.
            raise KeyError(
                f"Could not find tracer mapping for {name_before_split} "
                f"(remapped from {qualname})"
            ) from None

test/test_fwd.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ def run_worker(args):
8383
f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}"
8484
)
8585

86+
# Test qualname mapping
87+
sd = stage.submod.state_dict()
88+
print(f"Rank {args.rank} state dict keys: {sd.keys()}")
89+
remapped_keys = [stage.remap_qualname(k) for k in sd.keys()]
90+
print(f"Rank {args.rank} remapped keys: {remapped_keys}")
91+
# Confirm remapped keys are consistent with original model
92+
old_keys = mod.state_dict().keys()
93+
assert all(rk in old_keys for rk in remapped_keys)
94+
print(f"Qualname test passed")
95+
8696

8797
def main(args=None):
8898
parser = argparse.ArgumentParser()

test/test_ir.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -791,17 +791,6 @@ def test_remap_qualname(self):
791791
old_name in old_names
792792
), f"Remapped parameter {old_name} not found in {old_names}"
793793

794-
# Check qualname mapping for submodule
795-
# Not supported at the moment
796-
"""
797-
for _, stage_mod in ec_pipe.split_gm.named_children():
798-
for new_name, _ in stage_mod.named_parameters():
799-
old_name = stage_mod.remap_qualname(new_name)
800-
assert (
801-
old_name in old_names
802-
), f"Remapped parameter {old_name} not found in {old_names}"
803-
"""
804-
805794

806795
if __name__ == "__main__":
807796
unittest.main()

0 commit comments

Comments
 (0)