
Commit ef2f7de

SSYernar authored and facebook-github-bot committed
Added optimizer configuration that supports optimizer type, learning rate, momentum, and weight decay. (#3107)
Summary: This commit introduces enhancements to the optimizer configuration in TorchRec's benchmark pipeline: the optimizer type, learning rate, momentum, and weight decay can now be specified for both dense and sparse parameters.

Differential Revision: D76837924
1 parent 4225394 commit ef2f7de
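
For orientation, here is a minimal, hypothetical usage sketch of the new RunOptions fields introduced in the diff below. The import path and the values are assumptions for illustration, not part of this commit.

# Hypothetical sketch (not part of this commit): setting the new optimizer
# fields on RunOptions. The import path assumes RunOptions lives in the
# benchmark_train_pipeline module changed below; values are illustrative.
from torchrec.distributed.benchmark.benchmark_train_pipeline import RunOptions

run_option = RunOptions(
    dense_optimizer="Adam",            # looked up on torch.optim by name
    dense_lr=1e-3,
    dense_weight_decay=1e-4,           # optional; omitted -> optimizer default
    sparse_optimizer="EXACT_ADAGRAD",  # mapped onto EmbOptimType by name
    sparse_lr=0.05,
)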

File tree

2 files changed: +62 lines, -16 lines

torchrec/distributed/benchmark/benchmark_pipeline_utils.py
torchrec/distributed/benchmark/benchmark_train_pipeline.py

torchrec/distributed/benchmark/benchmark_pipeline_utils.py

Lines changed: 28 additions & 12 deletions
@@ -432,21 +432,23 @@ def generate_sharded_model_and_optimizer(
     kernel_type: str,
     pg: dist.ProcessGroup,
     device: torch.device,
-    fused_params: Optional[Dict[str, Any]] = None,
+    fused_params: Dict[str, Any],
+    dense_optimizer: str,
+    dense_lr: float,
+    dense_momentum: Optional[float],
+    dense_weight_decay: Optional[float],
     planner: Optional[
         Union[
             EmbeddingShardingPlanner,
             HeteroEmbeddingShardingPlanner,
         ]
     ] = None,
 ) -> Tuple[nn.Module, Optimizer]:
-    # Ensure fused_params is always a dictionary
-    fused_params_dict = {} if fused_params is None else fused_params

     sharder = TestEBCSharder(
         sharding_type=sharding_type,
         kernel_type=kernel_type,
-        fused_params=fused_params_dict,
+        fused_params=fused_params,
     )
     sharders = [cast(ModuleSharder[nn.Module], sharder)]

@@ -466,14 +468,28 @@ def generate_sharded_model_and_optimizer(
         sharders=sharders,
         plan=plan,
     ).to(device)
-    optimizer = optim.SGD(
-        [
-            param
-            for name, param in sharded_model.named_parameters()
-            if "sparse" not in name
-        ],
-        lr=0.1,
-    )
+
+    # Get dense parameters
+    dense_params = [
+        param
+        for name, param in sharded_model.named_parameters()
+        if "sparse" not in name
+    ]
+
+    # Create optimizer based on the specified type
+    optimizer_class = getattr(optim, dense_optimizer)
+
+    # Create optimizer with momentum and/or weight_decay if provided
+    optimizer_kwargs = {"lr": dense_lr}
+
+    if dense_momentum is not None:
+        optimizer_kwargs["momentum"] = dense_momentum
+
+    if dense_weight_decay is not None:
+        optimizer_kwargs["weight_decay"] = dense_weight_decay
+
+    optimizer = optimizer_class(dense_params, **optimizer_kwargs)
+
     return sharded_model, optimizer
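
To make the new dense-parameter path concrete, the following is a self-contained sketch of the same pattern on a toy model. The build_dense_optimizer helper and the toy module are hypothetical names for illustration, not TorchRec APIs.

# Standalone sketch of the dense-optimizer pattern used above (hypothetical
# helper and toy model; not TorchRec code). The optimizer class is looked up
# by name on torch.optim, and optional hyperparameters are passed only when
# provided.
from typing import Any, Dict, Optional

import torch.nn as nn
import torch.optim as optim


def build_dense_optimizer(
    model: nn.Module,
    name: str = "SGD",
    lr: float = 0.1,
    momentum: Optional[float] = None,
    weight_decay: Optional[float] = None,
) -> optim.Optimizer:
    optimizer_class = getattr(optim, name)  # e.g. optim.SGD, optim.Adam, optim.AdamW
    kwargs: Dict[str, Any] = {"lr": lr}
    if momentum is not None:  # only SGD/RMSprop-style optimizers accept momentum
        kwargs["momentum"] = momentum
    if weight_decay is not None:
        kwargs["weight_decay"] = weight_decay
    # Mirror the benchmark's filter: keep only the "dense" (non-sparse) parameters.
    dense_params = [p for n, p in model.named_parameters() if "sparse" not in n]
    return optimizer_class(dense_params, **kwargs)


toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))
opt = build_dense_optimizer(toy, name="AdamW", lr=1e-3, weight_decay=1e-2)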

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 34 additions & 4 deletions
@@ -83,6 +83,14 @@ class RunOptions:
         pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table.
             This is the average number of values each sample has for the feature.
         num_poolings (Optional[List[float]]): Number of poolings for each feature of the table.
+        dense_optimizer (str): Optimizer to use for dense parameters.
+            Default is "SGD".
+        dense_lr (float): Learning rate for dense parameters.
+            Default is 0.1.
+        sparse_optimizer (str): Optimizer to use for sparse parameters.
+            Default is "EXACT_ADAGRAD".
+        sparse_lr (float): Learning rate for sparse parameters.
+            Default is 0.1.
     """

     world_size: int = 2
@@ -94,6 +102,14 @@ class RunOptions:
     planner_type: str = "embedding"
     pooling_factors: Optional[List[float]] = None
     num_poolings: Optional[List[float]] = None
+    dense_optimizer: str = "SGD"
+    dense_lr: float = 0.1
+    dense_momentum: Optional[float] = None
+    dense_weight_decay: Optional[float] = None
+    sparse_optimizer: str = "EXACT_ADAGRAD"
+    sparse_lr: float = 0.1
+    sparse_momentum: Optional[float] = None
+    sparse_weight_decay: Optional[float] = None


 @dataclass
@@ -286,17 +302,31 @@ def runner(
         num_batches=run_option.num_batches,
     )

+    # Prepare fused_params for sparse optimizer
+    fused_params = {
+        "optimizer": getattr(EmbOptimType, run_option.sparse_optimizer.upper()),
+        "learning_rate": run_option.sparse_lr,
+    }
+
+    # Add momentum and weight_decay to fused_params if provided
+    if run_option.sparse_momentum is not None:
+        fused_params["momentum"] = run_option.sparse_momentum
+
+    if run_option.sparse_weight_decay is not None:
+        fused_params["weight_decay"] = run_option.sparse_weight_decay
+
     sharded_model, optimizer = generate_sharded_model_and_optimizer(
         model=unsharded_model,
         sharding_type=run_option.sharding_type.value,
         kernel_type=run_option.compute_kernel.value,
         # pyre-ignore
         pg=ctx.pg,
         device=ctx.device,
-        fused_params={
-            "optimizer": EmbOptimType.EXACT_ADAGRAD,
-            "learning_rate": 0.1,
-        },
+        fused_params=fused_params,
+        dense_optimizer=run_option.dense_optimizer,
+        dense_lr=run_option.dense_lr,
+        dense_momentum=run_option.dense_momentum,
+        dense_weight_decay=run_option.dense_weight_decay,
         planner=planner,
     )
     pipeline = generate_pipeline(
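
The sparse side mirrors the dense path: the runner builds fused_params by mapping the configured name onto an EmbOptimType member. Below is a minimal sketch of that mapping; the fbgemm_gpu import path and the build_fused_params helper are assumptions for illustration (the diff does not show the module's imports).

# Sketch of the fused_params construction used in runner() above. The
# fbgemm_gpu import path is an assumption; build_fused_params is a
# hypothetical helper standing in for the inline code in the diff.
from typing import Any, Dict, Optional

from fbgemm_gpu.split_embedding_configs import EmbOptimType


def build_fused_params(
    sparse_optimizer: str = "EXACT_ADAGRAD",
    sparse_lr: float = 0.1,
    sparse_momentum: Optional[float] = None,
    sparse_weight_decay: Optional[float] = None,
) -> Dict[str, Any]:
    # "exact_adagrad" -> EmbOptimType.EXACT_ADAGRAD; raises AttributeError if
    # the upper-cased name is not a member of the enum.
    fused_params: Dict[str, Any] = {
        "optimizer": getattr(EmbOptimType, sparse_optimizer.upper()),
        "learning_rate": sparse_lr,
    }
    if sparse_momentum is not None:
        fused_params["momentum"] = sparse_momentum
    if sparse_weight_decay is not None:
        fused_params["weight_decay"] = sparse_weight_decay
    return fused_params


fused = build_fused_params("exact_adagrad", sparse_lr=0.05, sparse_weight_decay=1e-4)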
