Skip to content

Commit 824457a

Browse files
author
Orbax Authors
committed
Support the XLA GPU compilation flags in Orbax
PiperOrigin-RevId: 879684156
1 parent 8853f3f commit 824457a

File tree

3 files changed

+230
-7
lines changed

3 files changed

+230
-7
lines changed

export/orbax/export/modules/obm_module_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,36 @@ def test_obm_module_bfloat16_conversion(self, enable_bf16_optimization):
384384
with self.subTest('test_weights_b_dtype'):
385385
self.assertEqual(module.model_params['b'].dtype, expected_dtype)
386386

387+
def test_obm_module_gpu_xla_flags_integration_stable(self):
  """Checks GPU XLA flags flow from Jax2ObmOptions into compile options.

  End-to-end: an ObmModule built with `xla_flags_per_platform` for 'cuda'
  must expose a compile-options map whose 'cuda' entry carries the flag as
  an `env_option_overrides` key (without the leading '--').
  """
  param_shape = (2, 5)
  param_dtype = jnp.dtype(jnp.float32)
  param_spec = jax.ShapeDtypeStruct(shape=param_shape, dtype=param_dtype)
  model_function_name = 'simple_add'

  jax2obm_options = obm_configs.Jax2ObmOptions(
      checkpoint_path='checkpoint_path',
      native_serialization_platforms=('cuda',),
      xla_flags_per_platform={
          'cuda': ['--xla_gpu_enable_latency_hiding_scheduler=true']
      },
  )

  orbax_model_module = obm_module.ObmModule(
      params=param_spec,
      apply_fn={model_function_name: simple_add},
      jax2obm_options=jax2obm_options,
  )

  xla_compile_options_map = (
      orbax_model_module.xla_compile_options_per_platform
  )
  self.assertIsNotNone(xla_compile_options_map)
  build_options_cuda = xla_compile_options_map.map['cuda']
  # The '--' prefix is stripped when the flag becomes an override key.
  self.assertIn(
      'xla_gpu_enable_latency_hiding_scheduler',
      build_options_cuda.env_option_overrides,
  )
416+
387417

388418
class GetSharedValueTest(parameterized.TestCase):
389419

model/orbax/experimental/model/core/python/compile_options_util.py

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .google.protobuf import any_pb2
2828
from .platforms.xla.service.jellyfish import tpu_compilation_environment_pb2 as tpu_comp_env_pb2
2929
from .platforms.xla.service.jellyfish.python import tpu_compilation_environment as tpu_comp_env
30+
from .third_party.neptune.model._src.core import xla_gpu_flag_validation
3031
from tensorflow.compiler.xla import xla_data_pb2 # pylint: disable=g-direct-tensorflow-import
3132
from tensorflow.compiler.xla import xla_pb2 # pylint: disable=g-direct-tensorflow-import
3233
from tensorflow.compiler.xla.pjrt.proto import compile_options_pb2 # pylint: disable=g-direct-tensorflow-import
@@ -165,6 +166,10 @@ def generate_xla_compile_options(
165166
tpu_platform_name = manifest_pb2.Platform.Name(
166167
manifest_pb2.Platform.TPU
167168
).lower()
169+
cuda_platform_name = manifest_pb2.Platform.Name(
170+
manifest_pb2.Platform.CUDA
171+
).lower()
172+
168173
compile_options_map = manifest_pb2.CompileOptionsProtoMap()
169174
if native_serialization_platforms is None:
170175
# If no native serialization platforms are specified, we will set the
@@ -195,24 +200,97 @@ def generate_xla_compile_options(
195200
)
196201

197202
for platform in platforms:
198-
if platform.lower() == tpu_platform_name:
199-
if xla_flags_per_platform:
200-
xla_flags_overrides = xla_flags_per_platform.get(platform, None)
203+
if xla_flags_per_platform:
204+
xla_flags_overrides = xla_flags_per_platform.get(platform, None)
205+
if xla_flags_overrides:
201206
_validate_xla_flags_setting(xla_flags_overrides, persist_xla_flags)
202-
else:
203-
xla_flags_overrides = None
207+
else:
208+
xla_flags_overrides = None
209+
210+
platform_lower = platform.lower()
211+
if platform_lower == tpu_platform_name:
204212
compile_environment = _generate_tpu_compilation_env(xla_flags_overrides)
213+
elif platform_lower == cuda_platform_name:
214+
# GPU Trick: Empty proto to bypass 'None' check and enable jax_mesh
215+
# serialization.
216+
compile_environment = xla_pb2.CompilationEnvironmentsProto()
205217
else:
218+
# CPU Path: Leave as None to preserve legacy portable execution behavior.
206219
compile_environment = None
207-
compile_options_map.map[platform.lower()].CopyFrom(
208-
_generate_compilation_options(compile_environment, jax_mesh)
220+
221+
compile_options = _generate_compilation_options(
222+
compile_environment, jax_mesh
209223
)
224+
225+
# Inject env_option_overrides natively for GPU using a dedicated helper.
226+
if platform_lower == cuda_platform_name and xla_flags_overrides:
227+
_apply_gpu_compilation_env_options(compile_options, xla_flags_overrides)
228+
229+
compile_options_map.map[platform_lower].CopyFrom(compile_options)
230+
210231
if not persist_xla_flags:
211232
for compile_options in compile_options_map.map.values():
212233
compile_options.executable_build_options.comp_envs.Clear()
213234
return compile_options_map
214235

215236

237+
def _apply_gpu_compilation_env_options(
    compile_options: compile_options_pb2.CompileOptionsProto,
    xla_flags_overrides: Sequence[str],
) -> None:
  """Applies XLA flag overrides generically for GPU platforms.

  Each flag is parsed into a typed `OptionOverrideProto` and merged into
  `compile_options.env_option_overrides`, keyed by the flag name without
  its '--' prefix. Existing entries with the same key are overwritten.

  Args:
    compile_options: The compilation options proto to be modified in place.
    xla_flags_overrides: A sequence of XLA flags (e.g.
      '--xla_gpu_autotune_level=4') to apply as option overrides.
  """
  overrides_map = _parse_env_option_overrides(xla_flags_overrides)
  for key, override in overrides_map.items():
    compile_options.env_option_overrides[key].CopyFrom(override)
250+
251+
252+
def _parse_env_option_overrides(
    xla_flags: Sequence[str],
) -> dict[str, compile_options_pb2.OptionOverrideProto]:
  """Parses a list of XLA flags into a dictionary of OptionOverrideProto.

  Each flag must look like '--name=value'. The value's proto field is
  inferred in order: bool ('true'/'false', case-insensitive), int, float,
  and finally a plain string fallback.

  Args:
    xla_flags: XLA GPU flags, each with a leading '--'.

  Returns:
    Mapping from flag name (without '--') to its typed override proto.

  Raises:
    ValueError: If a flag lacks the '--' prefix, or fails GPU flag
      validation.
  """
  overrides = {}
  for flag in xla_flags:
    if not flag.startswith('--'):
      raise ValueError(f"Flag {flag} must start with '--'")

    try:
      # Use the C++ ValidateXlaGPUFlag logic to ensure consistent policy
      # enforcement across Python and C++ layers.
      # The C++ function expects the flag with the '--' prefix.
      xla_gpu_flag_validation.validate_xla_gpu_flag(flag, strict=True)
    except Exception as e:
      # pybind11_abseil appends the status code name to the exception string.
      # Remove it to match exactly what users would see from the C++ binaries.
      # str.removesuffix is already a no-op when the suffix is absent, so no
      # endswith() guard is needed.
      err_msg = str(e).removesuffix(' [INVALID_ARGUMENT]')
      raise ValueError(err_msg) from e

    # Split only on the first '=' so values may themselves contain '='.
    key, value = flag[2:].split('=', 1)
    override_proto = compile_options_pb2.OptionOverrideProto()

    # Infer type (True/False/Int/Float/String).
    lowered = value.lower()
    if lowered == 'true':
      override_proto.bool_field = True
    elif lowered == 'false':
      override_proto.bool_field = False
    elif value.isdigit() or (value.startswith('-') and value[1:].isdigit()):
      override_proto.int_field = int(value)
    else:
      try:
        override_proto.double_field = float(value)
      except ValueError:
        override_proto.string_field = value

    overrides[key] = override_proto
  return overrides
292+
293+
216294
def _validate_xla_flags_setting(
217295
xla_flags_overrides: Sequence[str] | None, persist_xla_flags: bool
218296
) -> None:

model/orbax/experimental/model/core/python/compile_options_util_test.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,121 @@ def test_generate_xla_compile_options_xla_flags_no_persist_raise_error(self):
328328
persist_xla_flags=False,
329329
)
330330

331+
def test_generate_xla_compile_options_env_overrides(self):
  """Verifies GPU flags become typed env_option_overrides for 'cuda'.

  A bool-valued flag must land in `bool_field` and an int-valued flag in
  `int_field`, both keyed without the '--' prefix.
  """
  compile_options_map = compile_options_util.generate_xla_compile_options(
      native_serialization_platforms=['cuda'],
      xla_flags_per_platform={
          'cuda': [
              '--xla_gpu_enable_latency_hiding_scheduler=true',
              '--xla_gpu_autotune_level=0',
          ]
      },
      persist_xla_flags=True,
  )
  self.assertIn('cuda', compile_options_map.map)
  compile_options = compile_options_map.map['cuda']

  overrides = compile_options.env_option_overrides
  self.assertIn('xla_gpu_enable_latency_hiding_scheduler', overrides)
  self.assertTrue(
      overrides['xla_gpu_enable_latency_hiding_scheduler'].bool_field
  )

  self.assertIn('xla_gpu_autotune_level', overrides)
  self.assertEqual(overrides['xla_gpu_autotune_level'].int_field, 0)
353+
354+
def test_generate_xla_compile_options_gpu_flags_experimental_rejection(self):
  """Verifies unsupported GPU flags are rejected with the C++ error text."""
  with self.assertRaisesRegex(
      ValueError,
      r'XLA GPU compilation flag --xla_gpu_experimental_flag=true is not'
      r' supported. Please check field description at'
      r' CompilationConfig::xla_gpu_flags',
  ):
    compile_options_util.generate_xla_compile_options(
        native_serialization_platforms=['cuda'],
        xla_flags_per_platform={'cuda': ['--xla_gpu_experimental_flag=true']},
        persist_xla_flags=True,
    )
366+
367+
@parameterized.named_parameters(
    dict(
        testcase_name='bool_true',
        flag='--xla_gpu_enable_latency_hiding_scheduler=true',
        expected_key='xla_gpu_enable_latency_hiding_scheduler',
        expected_field='bool_field',
        expected_value=True,
    ),
    dict(
        testcase_name='bool_false',
        flag='--xla_gpu_enable_latency_hiding_scheduler=false',
        expected_key='xla_gpu_enable_latency_hiding_scheduler',
        expected_field='bool_field',
        expected_value=False,
    ),
    dict(
        testcase_name='bool_uppercase_true',
        flag='--xla_gpu_enable_latency_hiding_scheduler=TRUE',
        expected_key='xla_gpu_enable_latency_hiding_scheduler',
        expected_field='bool_field',
        expected_value=True,
    ),
    dict(
        testcase_name='int_positive',
        flag='--xla_gpu_autotune_level=4',
        expected_key='xla_gpu_autotune_level',
        expected_field='int_field',
        expected_value=4,
    ),
    dict(
        testcase_name='int_negative',
        flag='--xla_gpu_nccl_termination_timeout_seconds=-1',
        expected_key='xla_gpu_nccl_termination_timeout_seconds',
        expected_field='int_field',
        expected_value=-1,
    ),
    dict(
        testcase_name='float_positive',
        flag='--xla_gpu_auto_spmd_partitioning_memory_budget_ratio=1.5',
        expected_key='xla_gpu_auto_spmd_partitioning_memory_budget_ratio',
        expected_field='double_field',
        expected_value=1.5,
    ),
    dict(
        testcase_name='float_negative',
        flag='--xla_gpu_auto_spmd_partitioning_memory_budget_ratio=-0.5',
        expected_key='xla_gpu_auto_spmd_partitioning_memory_budget_ratio',
        expected_field='double_field',
        expected_value=-0.5,
    ),
    dict(
        testcase_name='string_value',
        flag='--xla_gpu_cuda_data_dir=/usr/local/cuda',
        expected_key='xla_gpu_cuda_data_dir',
        expected_field='string_field',
        expected_value='/usr/local/cuda',
    ),
)
@mock.patch.object(
    compile_options_util.xla_gpu_flag_validation, 'validate_xla_gpu_flag'
)
def test_generate_xla_compile_options_gpu_flags_type_inference(
    self, mock_validate, flag, expected_key, expected_field, expected_value
):
  """Checks value-type inference (bool/int/float/string) for GPU flags.

  Flag validation is patched out so flags of every value type can be fed
  through regardless of the validator's allow-list.
  """
  del mock_validate  # Unused, just patching for bypass
  compile_options_map = compile_options_util.generate_xla_compile_options(
      native_serialization_platforms=['cuda'],
      xla_flags_per_platform={'cuda': [flag]},
      persist_xla_flags=True,
  )
  self.assertIsNotNone(compile_options_map.map)
  build_options_cuda = compile_options_map.map['cuda']
  self.assertIn(expected_key, build_options_cuda.env_option_overrides)
  override_proto = build_options_cuda.env_option_overrides[expected_key]
  with self.subTest('test_oneof_field'):
    # WhichOneof confirms exactly one typed field of the oneof is set.
    self.assertEqual(override_proto.WhichOneof('value'), expected_field)
  with self.subTest('test_value'):
    self.assertEqual(getattr(override_proto, expected_field), expected_value)
445+
331446
@parameterized.named_parameters(
332447
dict(
333448
testcase_name='1d_mesh',

0 commit comments

Comments
 (0)