@@ -83,6 +83,14 @@ class RunOptions:
83
83
pooling_factors (Optional[List[float]]): Pooling factors for each feature of the table.
84
84
This is the average number of values each sample has for the feature.
85
85
num_poolings (Optional[List[float]]): Number of poolings for each feature of the table.
86
+ dense_optimizer (str): Optimizer to use for dense parameters.
87
+ Default is "SGD".
88
+ dense_lr (float): Learning rate for dense parameters.
89
+ Default is 0.1.
90
+ sparse_optimizer (str): Optimizer to use for sparse parameters.
91
+ Default is "EXACT_ADAGRAD".
92
+ sparse_lr (float): Learning rate for sparse parameters.
93
+ Default is 0.1.
86
94
"""
87
95
88
96
world_size : int = 2
@@ -94,6 +102,14 @@ class RunOptions:
94
102
planner_type : str = "embedding"
95
103
pooling_factors : Optional [List [float ]] = None
96
104
num_poolings : Optional [List [float ]] = None
105
+ dense_optimizer : str = "SGD"
106
+ dense_lr : float = 0.1
107
+ dense_momentum : Optional [float ] = None
108
+ dense_weight_decay : Optional [float ] = None
109
+ sparse_optimizer : str = "EXACT_ADAGRAD"
110
+ sparse_lr : float = 0.1
111
+ sparse_momentum : Optional [float ] = None
112
+ sparse_weight_decay : Optional [float ] = None
97
113
98
114
99
115
@dataclass
@@ -286,17 +302,31 @@ def runner(
286
302
num_batches = run_option .num_batches ,
287
303
)
288
304
305
+ # Prepare fused_params for sparse optimizer
306
+ fused_params = {
307
+ "optimizer" : getattr (EmbOptimType , run_option .sparse_optimizer .upper ()),
308
+ "learning_rate" : run_option .sparse_lr ,
309
+ }
310
+
311
+ # Add momentum and weight_decay to fused_params if provided
312
+ if run_option .sparse_momentum is not None :
313
+ fused_params ["momentum" ] = run_option .sparse_momentum
314
+
315
+ if run_option .sparse_weight_decay is not None :
316
+ fused_params ["weight_decay" ] = run_option .sparse_weight_decay
317
+
289
318
sharded_model , optimizer = generate_sharded_model_and_optimizer (
290
319
model = unsharded_model ,
291
320
sharding_type = run_option .sharding_type .value ,
292
321
kernel_type = run_option .compute_kernel .value ,
293
322
# pyre-ignore
294
323
pg = ctx .pg ,
295
324
device = ctx .device ,
296
- fused_params = {
297
- "optimizer" : EmbOptimType .EXACT_ADAGRAD ,
298
- "learning_rate" : 0.1 ,
299
- },
325
+ fused_params = fused_params ,
326
+ dense_optimizer = run_option .dense_optimizer ,
327
+ dense_lr = run_option .dense_lr ,
328
+ dense_momentum = run_option .dense_momentum ,
329
+ dense_weight_decay = run_option .dense_weight_decay ,
300
330
planner = planner ,
301
331
)
302
332
pipeline = generate_pipeline (
0 commit comments