
Commit df9d96f

Author: Orbax Authors
Support the XLA GPU compilation flags in Orbax
PiperOrigin-RevId: 879684156
1 parent 519af35 commit df9d96f

3 files changed: +222 −7 lines
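For context, the sketch below shows how the new `xla_flags_per_platform` option is intended to be used from the export API, mirroring the test added in this commit. The import paths, the `simple_add` model function, and the surrounding setup are illustrative assumptions, not part of the diff.

# Illustrative usage sketch based on the tests added below; the import paths
# and the `simple_add` model function are assumptions, not taken from the diff.
import jax
import jax.numpy as jnp

from orbax.export import obm_configs            # assumed import path
from orbax.export.modules import obm_module     # assumed import path


def simple_add(params, inputs):
  # Hypothetical stand-in for the test's `simple_add` model function.
  return params + inputs


param_spec = jax.ShapeDtypeStruct(shape=(2, 5), dtype=jnp.dtype(jnp.float32))

# The new `xla_flags_per_platform` field forwards XLA GPU flags into the
# per-platform compile options of the exported model.
jax2obm_options = obm_configs.Jax2ObmOptions(
    checkpoint_path='checkpoint_path',
    native_serialization_platforms=('cuda',),
    xla_flags_per_platform={
        'cuda': ['--xla_gpu_enable_latency_hiding_scheduler=true']
    },
)

module = obm_module.ObmModule(
    params=param_spec,
    apply_fn={'simple_add': simple_add},
    jax2obm_options=jax2obm_options,
)

# The flags surface as `env_option_overrides` on the CUDA compile options.
cuda_options = module.xla_compile_options_per_platform.map['cuda']
print(cuda_options.env_option_overrides)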

export/orbax/export/modules/obm_module_test.py

Lines changed: 30 additions & 0 deletions
@@ -384,6 +384,36 @@ def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization):
     with self.subTest('test_weights_b_dtype'):
       self.assertEqual(module.model_params['b'].dtype, expected_dtype)

+  def test_obm_module_gpu_xla_flags_integration_stable(self):
+    param_shape = (2, 5)
+    param_dtype = jnp.dtype(jnp.float32)
+    param_spec = jax.ShapeDtypeStruct(shape=param_shape, dtype=param_dtype)
+    model_function_name = 'simple_add'
+
+    jax2obm_options = obm_configs.Jax2ObmOptions(
+        checkpoint_path='checkpoint_path',
+        native_serialization_platforms=('cuda',),
+        xla_flags_per_platform={
+            'cuda': ['--xla_gpu_enable_latency_hiding_scheduler=true']
+        },
+    )
+
+    orbax_model_module = obm_module.ObmModule(
+        params=param_spec,
+        apply_fn={model_function_name: simple_add},
+        jax2obm_options=jax2obm_options,
+    )
+
+    xla_compile_options_map = (
+        orbax_model_module.xla_compile_options_per_platform
+    )
+    self.assertIsNotNone(xla_compile_options_map)
+    build_options_cuda = xla_compile_options_map.map['cuda']
+    self.assertIn(
+        'xla_gpu_enable_latency_hiding_scheduler',
+        build_options_cuda.env_option_overrides,
+    )
+

 class GetSharedValueTest(parameterized.TestCase):

model/orbax/experimental/model/core/python/compile_options_util.py

Lines changed: 79 additions & 7 deletions
@@ -16,6 +16,7 @@

 from collections.abc import Mapping, Sequence
 import logging
+import re

 from google.protobuf import descriptor
 import jax
@@ -92,12 +93,14 @@ def _generate_tpu_compilation_env(
 def _generate_compilation_options(
     compile_environment: xla_pb2.CompilationEnvironmentsProto | None = None,
     jax_mesh: jax.sharding.Mesh | None = None,
+    populate_xla_build_options: bool = False,
 ) -> compile_options_pb2.CompileOptionsProto:
   """Generates the compilation options for the given compilation environment."""
   compile_options = compile_options_pb2.CompileOptionsProto()
   executable_build_options = compile_options_pb2.ExecutableBuildOptionsProto()
   if compile_environment is not None:
     executable_build_options.comp_envs.CopyFrom(compile_environment)
+  if populate_xla_build_options:
     executable_build_options.num_replicas = 1
     executable_build_options.num_partitions = 1
     executable_build_options.device_ordinal = -1
@@ -168,6 +171,10 @@ def generate_xla_compile_options(
   tpu_platform_name = manifest_pb2.Platform.Name(
       manifest_pb2.Platform.TPU
   ).lower()
+  cuda_platform_name = manifest_pb2.Platform.Name(
+      manifest_pb2.Platform.CUDA
+  ).lower()
+
   compile_options_map = manifest_pb2.CompileOptionsProtoMap()
   if native_serialization_platforms is None:
     # If no native serialization platforms are specified, we will set the
@@ -198,24 +205,89 @@ def generate_xla_compile_options(
   )

   for platform in platforms:
-    if platform.lower() == tpu_platform_name:
-      if xla_flags_per_platform:
-        xla_flags_overrides = xla_flags_per_platform.get(platform, None)
+    if xla_flags_per_platform:
+      xla_flags_overrides = xla_flags_per_platform.get(platform, None)
+      if xla_flags_overrides:
        _validate_xla_flags_setting(xla_flags_overrides, persist_xla_flags)
-      else:
-        xla_flags_overrides = None
+    else:
+      xla_flags_overrides = None
+
+    platform_lower = platform.lower()
+    if platform_lower == tpu_platform_name:
      compile_environment = _generate_tpu_compilation_env(xla_flags_overrides)
    else:
+      # CPU Path: Leave as None to preserve legacy portable execution behavior.
+      # CUDA Path: No specialized compiler environment needed by default.
      compile_environment = None
-    compile_options_map.map[platform.lower()].CopyFrom(
-        _generate_compilation_options(compile_environment, jax_mesh)
+
+    compile_options = _generate_compilation_options(
+        compile_environment,
+        jax_mesh,
+        populate_xla_build_options=(
+            platform_lower in (tpu_platform_name, cuda_platform_name)
+        ),
    )
+
+    # Inject env_option_overrides natively for GPU using a dedicated helper.
+    if platform_lower == cuda_platform_name and xla_flags_overrides:
+      _apply_gpu_compilation_env_options(compile_options, xla_flags_overrides)
+
+    compile_options_map.map[platform_lower].CopyFrom(compile_options)
+
  if not persist_xla_flags:
    for compile_options in compile_options_map.map.values():
      compile_options.executable_build_options.comp_envs.Clear()
+      compile_options.env_option_overrides.clear()
  return compile_options_map


+def _apply_gpu_compilation_env_options(
+    compile_options: compile_options_pb2.CompileOptionsProto,
+    xla_flags_overrides: Sequence[str],
+) -> None:
+  """Applies XLA flag overrides generically for GPU platforms.
+
+  Args:
+    compile_options: The compilation options proto to be modified.
+    xla_flags_overrides: A sequence of XLA flags to apply as option overrides.
+  """
+  overrides_map = _parse_env_option_overrides_for_gpu(xla_flags_overrides)
+  for k, v in overrides_map.items():
+    compile_options.env_option_overrides[k].CopyFrom(v)
+
+
+def _parse_env_option_overrides_for_gpu(
+    xla_flags: Sequence[str],
+) -> dict[str, compile_options_pb2.OptionOverrideProto]:
+  """Parses a list of XLA GPU flags into a dictionary of OptionOverrideProto."""
+  overrides = {}
+  for flag in xla_flags:
+    if not flag.startswith("--"):
+      raise ValueError(f"Flag {flag} must start with '--'")
+
+    # Ensure consistent policy enforcement.
+    _validate_xla_gpu_flag(flag, strict=True)
+
+    key, value = flag[2:].split("=", 1)
+    override_proto = compile_options_pb2.OptionOverrideProto()
+
+    # Infer type (True/False/Int/Float/String)
+    if value.lower() == "true":
+      override_proto.bool_field = True
+    elif value.lower() == "false":
+      override_proto.bool_field = False
+    elif value.isdigit() or (value.startswith("-") and value[1:].isdigit()):
+      override_proto.int_field = int(value)
+    else:
+      try:
+        override_proto.double_field = float(value)
+      except ValueError:
+        override_proto.string_field = value
+
+    overrides[key] = override_proto
+  return overrides
+
+
 def _validate_xla_flags_setting(
     xla_flags_overrides: Sequence[str] | None, persist_xla_flags: bool
 ) -> None:
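The parser above maps each `--flag=value` string onto the `OptionOverrideProto` oneof by inspecting the value text: case-insensitive `true`/`false` become booleans, integer literals (including negatives) become ints, anything parseable as a float becomes a double, and everything else stays a string. A minimal standalone restatement of that precedence (the helper name is illustrative, not part of the library):

def infer_override_value(value: str) -> bool | int | float | str:
  """Mirrors the bool -> int -> float -> string precedence used above."""
  if value.lower() == 'true':
    return True
  if value.lower() == 'false':
    return False
  if value.isdigit() or (value.startswith('-') and value[1:].isdigit()):
    return int(value)
  try:
    return float(value)
  except ValueError:
    return value


# Matches the expectations in the type-inference test below.
assert infer_override_value('TRUE') is True
assert infer_override_value('-1') == -1
assert infer_override_value('1.5') == 1.5
assert infer_override_value('/usr/local/cuda') == '/usr/local/cuda'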

model/orbax/experimental/model/core/python/compile_options_util_test.py

Lines changed: 113 additions & 0 deletions
@@ -332,6 +332,119 @@ def test_generate_xla_compile_options_xla_flags_no_persist_raise_error(self):
          persist_xla_flags=False,
      )

+  def test_generate_xla_compile_options_gpu_flags_env_overrides(self):
+    compile_options_map = compile_options_util.generate_xla_compile_options(
+        native_serialization_platforms=['cuda'],
+        xla_flags_per_platform={
+            'cuda': [
+                '--xla_gpu_enable_latency_hiding_scheduler=true',
+                '--xla_gpu_autotune_level=0',
+            ]
+        },
+        persist_xla_flags=True,
+    )
+    self.assertIn('cuda', compile_options_map.map)
+    compile_options = compile_options_map.map['cuda']
+
+    overrides = compile_options.env_option_overrides
+    self.assertIn('xla_gpu_enable_latency_hiding_scheduler', overrides)
+    self.assertTrue(
+        overrides['xla_gpu_enable_latency_hiding_scheduler'].bool_field
+    )
+
+    self.assertIn('xla_gpu_autotune_level', overrides)
+    self.assertEqual(overrides['xla_gpu_autotune_level'].int_field, 0)
+
+  def test_generate_xla_compile_options_gpu_flags_experimental_rejection(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        r'XLA GPU compilation flag --xla_gpu_experimental_flag=true is not'
+        r' supported. Please check field description at'
+        r' CompilationConfig::xla_gpu_flags',
+    ):
+      compile_options_util.generate_xla_compile_options(
+          native_serialization_platforms=['cuda'],
+          xla_flags_per_platform={'cuda': ['--xla_gpu_experimental_flag=true']},
+          persist_xla_flags=True,
+      )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='bool_true',
+          flag='--xla_gpu_enable_latency_hiding_scheduler=true',
+          expected_key='xla_gpu_enable_latency_hiding_scheduler',
+          expected_field='bool_field',
+          expected_value=True,
+      ),
+      dict(
+          testcase_name='bool_false',
+          flag='--xla_gpu_enable_latency_hiding_scheduler=false',
+          expected_key='xla_gpu_enable_latency_hiding_scheduler',
+          expected_field='bool_field',
+          expected_value=False,
+      ),
+      dict(
+          testcase_name='bool_uppercase_true',
+          flag='--xla_gpu_enable_latency_hiding_scheduler=TRUE',
+          expected_key='xla_gpu_enable_latency_hiding_scheduler',
+          expected_field='bool_field',
+          expected_value=True,
+      ),
+      dict(
+          testcase_name='int_positive',
+          flag='--xla_gpu_autotune_level=4',
+          expected_key='xla_gpu_autotune_level',
+          expected_field='int_field',
+          expected_value=4,
+      ),
+      dict(
+          testcase_name='int_negative',
+          flag='--xla_gpu_nccl_termination_timeout_seconds=-1',
+          expected_key='xla_gpu_nccl_termination_timeout_seconds',
+          expected_field='int_field',
+          expected_value=-1,
+      ),
+      dict(
+          testcase_name='float_positive',
+          flag='--xla_gpu_auto_spmd_partitioning_memory_budget_ratio=1.5',
+          expected_key='xla_gpu_auto_spmd_partitioning_memory_budget_ratio',
+          expected_field='double_field',
+          expected_value=1.5,
+      ),
+      dict(
+          testcase_name='float_negative',
+          flag='--xla_gpu_auto_spmd_partitioning_memory_budget_ratio=-0.5',
+          expected_key='xla_gpu_auto_spmd_partitioning_memory_budget_ratio',
+          expected_field='double_field',
+          expected_value=-0.5,
+      ),
+      dict(
+          testcase_name='string_value',
+          flag='--xla_gpu_cuda_data_dir=/usr/local/cuda',
+          expected_key='xla_gpu_cuda_data_dir',
+          expected_field='string_field',
+          expected_value='/usr/local/cuda',
+      ),
+  )
+  @mock.patch.object(compile_options_util, '_validate_xla_gpu_flag')
+  def test_generate_xla_compile_options_gpu_flags_type_inference(
+      self, mock_validate, flag, expected_key, expected_field, expected_value
+  ):
+    del mock_validate  # Unused, just patching for bypass
+    compile_options_map = compile_options_util.generate_xla_compile_options(
+        native_serialization_platforms=['cuda'],
+        xla_flags_per_platform={'cuda': [flag]},
+        persist_xla_flags=True,
+    )
+    self.assertIsNotNone(compile_options_map.map)
+    build_options_cuda = compile_options_map.map['cuda']
+    self.assertIn(expected_key, build_options_cuda.env_option_overrides)
+    override_proto = build_options_cuda.env_option_overrides[expected_key]
+    with self.subTest('test_oneof_field'):
+      self.assertEqual(override_proto.WhichOneof('value'), expected_field)
+    with self.subTest('test_value'):
+      self.assertEqual(getattr(override_proto, expected_field), expected_value)
+
   @parameterized.named_parameters(
       dict(
           testcase_name='1d_mesh',

0 commit comments
