Skip to content

Commit 5025063

Browse files
authored
Re-support kwargs at run time (#929)
## Description Implements #928. Users want the first pipeline stage to accept kwargs if the original program does. This is controlled by the `_codegen` field of the graph, as @angelayi suggests, so we make a copy from the traced program to submod0. ## Feature/Issue validation/testing Added kwargs in test_fwd.py. Also changed a few HF examples to pass kwargs directly.
1 parent e9e2d5f commit 5025063

File tree

10 files changed

+91
-48
lines changed

10 files changed

+91
-48
lines changed

examples/hf/pippy_albert.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def run(args):
4646
# Input configs
4747
example_inputs = generate_inputs_for_model(
4848
model_class, albert, model_name, args.batch_size, args.device)
49-
input_ids = example_inputs["input_ids"]
5049

5150
# Annotate split points
5251
add_split_points(albert, args.world_size)
@@ -55,7 +54,8 @@ def run(args):
5554
albert_pipe = Pipe.from_tracing(
5655
albert,
5756
num_chunks=args.chunks,
58-
example_args=(input_ids, ),
57+
example_args=(),
58+
example_kwargs=example_inputs,
5959
)
6060
nstages = len(list(albert_pipe.split_gm.children()))
6161
assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
@@ -72,7 +72,7 @@ def run(args):
7272

7373
# Run
7474
if args.rank == 0:
75-
stage(input_ids)
75+
stage(**example_inputs)
7676
elif args.rank == args.world_size - 1:
7777
out = stage()
7878
else:

examples/hf/pippy_bart.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def run(args):
4343
# Input configs
4444
example_inputs = generate_inputs_for_model(
4545
model_class, bart, model_name, args.batch_size, args.device)
46-
input_ids = example_inputs["input_ids"]
4746

4847
# Annotate split points
4948
add_split_points(bart, args.world_size)
@@ -52,7 +51,8 @@ def run(args):
5251
bart_pipe = Pipe.from_tracing(
5352
bart,
5453
num_chunks=args.chunks,
55-
example_args=(input_ids, ),
54+
example_args=(),
55+
example_kwargs=example_inputs,
5656
)
5757
nstages = len(list(bart_pipe.split_gm.children()))
5858
assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
@@ -69,7 +69,7 @@ def run(args):
6969

7070
# Run
7171
if args.rank == 0:
72-
stage(input_ids)
72+
stage(**example_inputs)
7373
elif args.rank == args.world_size - 1:
7474
out = stage()
7575
else:

examples/hf/pippy_bert.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def run(args):
4343
# Input configs
4444
example_inputs = generate_inputs_for_model(
4545
model_class, bert, model_name, args.batch_size, args.device)
46-
input_ids = example_inputs["input_ids"]
4746

4847
# Annotate split points
4948
add_split_points(bert, args.world_size)
@@ -52,7 +51,8 @@ def run(args):
5251
bert_pipe = Pipe.from_tracing(
5352
bert,
5453
num_chunks=args.chunks,
55-
example_args=(input_ids, ),
54+
example_args=(),
55+
example_kwargs=example_inputs,
5656
)
5757
nstages = len(list(bert_pipe.split_gm.children()))
5858
assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
@@ -69,7 +69,7 @@ def run(args):
6969

7070
# Run
7171
if args.rank == 0:
72-
stage(input_ids)
72+
stage(**example_inputs)
7373
elif args.rank == args.world_size - 1:
7474
out = stage()
7575
else:

examples/hf/pippy_camemBert.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def run(args):
4343
# Input configs
4444
example_inputs = generate_inputs_for_model(
4545
model_class, camembert, model_name, args.batch_size, args.device)
46-
input_ids = example_inputs["input_ids"]
4746

4847
# Annotate split points
4948
add_split_points(camembert, args.world_size)
@@ -52,7 +51,8 @@ def run(args):
5251
camembert_pipe = Pipe.from_tracing(
5352
camembert,
5453
num_chunks=args.chunks,
55-
example_args=(input_ids, ),
54+
example_args=(),
55+
example_kwargs=example_inputs,
5656
)
5757
nstages = len(list(camembert_pipe.split_gm.children()))
5858
assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
@@ -69,7 +69,7 @@ def run(args):
6969

7070
# Run
7171
if args.rank == 0:
72-
stage(input_ids)
72+
stage(**example_inputs)
7373
elif args.rank == args.world_size - 1:
7474
out = stage()
7575
else:

examples/hf/pippy_gpt2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def run(args):
5252
# Input configs
5353
example_inputs = generate_inputs_for_model(
5454
model_class, gpt2, model_name, args.batch_size, args.device)
55-
input_ids = example_inputs["input_ids"]
5655

5756
# Annotate split points
5857
add_split_points(gpt2, args.world_size)
@@ -61,7 +60,8 @@ def run(args):
6160
gpt2_pipe = Pipe.from_tracing(
6261
gpt2,
6362
num_chunks=args.chunks,
64-
example_args=(input_ids, ),
63+
example_args=(),
64+
example_kwargs=example_inputs,
6565
)
6666
assert len(list(gpt2_pipe.split_gm.children())) == args.world_size
6767
if args.rank == 0:
@@ -77,7 +77,7 @@ def run(args):
7777

7878
# Run
7979
if args.rank == 0:
80-
stage(input_ids)
80+
stage(**example_inputs)
8181
elif args.rank == args.world_size - 1:
8282
out = stage()
8383
else:

examples/hf/pippy_gptNeo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def run(args):
4343
# Input configs
4444
example_inputs = generate_inputs_for_model(
4545
model_class, gptneo, model_name, args.batch_size, args.device)
46-
input_ids = example_inputs["input_ids"]
4746

4847
# Annotate split points
4948
add_split_points(gptneo, args.world_size)
@@ -52,7 +51,8 @@ def run(args):
5251
gptneo_pipe = Pipe.from_tracing(
5352
gptneo,
5453
num_chunks=args.chunks,
55-
example_args=(input_ids, ),
54+
example_args=(),
55+
example_kwargs=example_inputs,
5656
)
5757
nstages = len(list(gptneo_pipe.split_gm.children()))
5858
assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
@@ -69,7 +69,7 @@ def run(args):
6969

7070
# Run
7171
if args.rank == 0:
72-
stage(input_ids)
72+
stage(**example_inputs)
7373
elif args.rank == args.world_size - 1:
7474
out = stage()
7575
else:

examples/hf/pippy_opt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def run(args):
4343
# Input configs
4444
example_inputs = generate_inputs_for_model(
4545
model_class, opt, model_name, args.batch_size, args.device)
46-
input_ids = example_inputs["input_ids"]
4746

4847
# Annotate split points
4948
add_split_points(opt, args.world_size)
@@ -52,7 +51,8 @@ def run(args):
5251
opt_pipe = Pipe.from_tracing(
5352
opt,
5453
num_chunks=args.chunks,
55-
example_args=(input_ids, ),
54+
example_args=(),
55+
example_kwargs=example_inputs,
5656
)
5757
nstages = len(list(opt_pipe.split_gm.children()))
5858
assert nstages == args.world_size, f"nstages = {nstages} nranks = {args.world_size}"
@@ -69,7 +69,7 @@ def run(args):
6969

7070
# Run
7171
if args.rank == 0:
72-
stage(input_ids)
72+
stage(**example_inputs)
7373
elif args.rank == args.world_size - 1:
7474
out = stage()
7575
else:

pippy/IR.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import operator
55
from enum import Enum
6+
from inspect import Parameter, signature, Signature
67
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
78

89
import torch
@@ -655,8 +656,6 @@ def throw(self, *args, **kwargs):
655656
def forward(self, *args, **kwargs):
656657
executor_args = args
657658
if len(kwargs) > 0:
658-
from inspect import Parameter, Signature
659-
660659
parameters = []
661660
for node in self.split_gm.graph.nodes:
662661
if node.op == "placeholder":
@@ -1005,6 +1004,34 @@ def move_param_to_callee(
10051004

10061005
split.delete_all_unused_submodules()
10071006

1007+
# Users want the first pipeline stage to accept kwargs if the original
1008+
# program does. This is controlled by the `_codegen` field of the graph,
1009+
# so we make a copy here. Note: we only want the input spec and not the
1010+
# output spec, because the output spec is for the last stage. Maybe a
1011+
# TODO? Not sure yet.
1012+
submod0 = list(split.children())[0]
1013+
model_sign = signature(traced.forward)
1014+
model_num_args = len(model_sign.parameters)
1015+
submod0_sign = signature(submod0.forward)
1016+
submod0_num_args = len(submod0_sign.parameters)
1017+
if model_num_args != submod0_num_args:
1018+
# We don't change the signature of the first stage if it takes
1019+
# a different number of args than the original model
1020+
logger.info(
1021+
f"Original model takes {model_num_args} args but the first pipeline stage takes {submod0_num_args}. "
1022+
"Please provide args to respective pipeline stages."
1023+
)
1024+
else:
1025+
# Support kwargs for the first stage
1026+
submod0.graph._codegen = copy.deepcopy(traced.graph._codegen)
1027+
# `_replace` is actually not "private" or internal. Based on this doc:
1028+
# To prevent conflicts with field names, the method and attribute names
1029+
# start with an underscore
1030+
submod0.graph._codegen.pytree_info = (
1031+
submod0.graph._codegen.pytree_info._replace(out_spec=None)
1032+
)
1033+
submod0.recompile()
1034+
10081035
split.graph.lint()
10091036
split.recompile()
10101037

@@ -1071,6 +1098,7 @@ def _trace_with_export(
10711098
example_kwargs,
10721099
constraints,
10731100
)
1101+
logger.debug(f"Traced model: {traced}")
10741102
if split_policy is not None:
10751103
traced = split_policy(traced)
10761104
finally:

pippy/PipelineStage.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch.distributed as dist
88
import torch.fx as fx
99
from torch._subclasses.fake_tensor import FakeTensor
10+
from torch.fx.node import map_aggregate, map_arg
1011
from torch.nn.parallel import DistributedDataParallel
1112

1213
from pippy.backward import stage_backward
@@ -47,6 +48,10 @@ class StageArgPlaceholder:
4748
pass
4849

4950

51+
class StageKwargPlaceholder:
52+
pass
53+
54+
5055
class PipelineStage(torch.nn.Module):
5156
def __init__(
5257
self,
@@ -269,14 +274,15 @@ def create_recv_tensor(
269274

270275
# `args` is a Tuple, hence we will have:
271276
# Tuple[RecvInfo]
272-
args_recv_info = fx.node.map_arg(self.node.args, create_recv_tensor)
277+
args_recv_info = map_arg(self.node.args, create_recv_tensor)
273278

274279
# `kwargs` is a Dict, hence we will have:
275280
# Dict[keyword, RecvInfo]
276-
kwargs_recv_info = fx.node.map_arg(self.node.kwargs, create_recv_tensor)
281+
kwargs_recv_info = map_arg(self.node.kwargs, create_recv_tensor)
277282

278283
logger.info(
279-
f"[{self.group_rank}] " f"Activation recv info: {args_recv_info}"
284+
f"[{self.group_rank}] "
285+
f"Activation recv / args info: {args_recv_info}"
280286
)
281287
return args_recv_info, kwargs_recv_info
282288

@@ -370,9 +376,9 @@ def map_recv_to_send(a):
370376
grad_send_info.append(None)
371377
return None
372378

373-
fx.node.map_aggregate(args_recv_info, map_recv_to_send)
379+
map_aggregate(args_recv_info, map_recv_to_send)
374380

375-
fx.node.map_aggregate(kwargs_recv_info, map_recv_to_send)
381+
map_aggregate(kwargs_recv_info, map_recv_to_send)
376382

377383
logger.info(f"[{self.group_rank}] " f"Grad send info: {grad_send_info}")
378384
return grad_send_info
@@ -422,35 +428,43 @@ def _recv_and_fill_inputs(
422428

423429
act_recv = self.recv_tensor_fn(recv_reqs)
424430

431+
chunk_args_list: List = []
425432
if self.args_split:
426433
chunk_args = self.args_split[chunk]
427434
chunk_args_list = list(chunk_args)
428435

429436
def recv_args(info):
430437
if isinstance(info, RecvInfo):
438+
# This is an activation to receive
431439
return act_recv(info)
432440
else:
433-
return chunk_args_list.pop(0) # type: ignore[has-type]
441+
# This is a pass-in argument
442+
if len(chunk_args_list):
443+
return chunk_args_list.pop(0) # type: ignore[has-type]
444+
else:
445+
# kwargs were treated as args in graph phase. That's why
446+
# there are extra placeholders here. We mark them and filter
447+
# them out later.
448+
return StageKwargPlaceholder()
434449

435-
composite_args = fx.node.map_aggregate(
450+
composite_args = map_aggregate(
436451
self.args_recv_info[chunk],
437452
recv_args,
438453
)
454+
# Filter out kwarg placeholders
455+
composite_args = tuple(
456+
x
457+
for x in composite_args
458+
if not isinstance(x, StageKwargPlaceholder)
459+
)
439460

461+
# Middle stages won't have incoming activations in kwargs form. So if
462+
# kwargs_split is not empty, it must be model inputs for stage 0. We
463+
# hence pass it as is to the internal submodule, without performing
464+
# `recv_args` on it.
465+
composite_kwargs: Dict = {}
440466
if self.kwargs_split:
441-
chunk_kwargs = self.kwargs_split[chunk]
442-
443-
def recv_kwargs(info):
444-
if isinstance(info, RecvInfo):
445-
return act_recv(info)
446-
else:
447-
k = next(iter(chunk_kwargs)) # type: ignore[has-type]
448-
return chunk_kwargs.pop(k) # type: ignore[has-type]
449-
450-
composite_kwargs = fx.node.map_aggregate(
451-
self.kwargs_recv_info[chunk],
452-
recv_kwargs,
453-
)
467+
composite_kwargs = self.kwargs_split[chunk]
454468

455469
# Wait for all recvs to finish
456470
for work in recv_reqs:
@@ -496,7 +510,7 @@ def _recv_grads(
496510
recv_grad = self.recv_tensor_fn(grad_recv_reqs)
497511

498512
# Receive gradients
499-
grads = fx.node.map_aggregate(
513+
grads = map_aggregate(
500514
self.grad_recv_info[bwd_chunk],
501515
recv_grad,
502516
)

0 commit comments

Comments
 (0)